# Load Libraries

In [None]:
import warnings
warnings.filterwarnings('ignore')

import ast
import joblib
import joblib
import librosa
import librosa.display

from IPython.display import Audio, display

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from scipy.signal import resample
from tqdm import tqdm
from collections import Counter
from pprint import pprint
%matplotlib inline

from sklearn import metrics

from transformers import WhisperFeatureExtractor, WhisperModel, WhisperProcessor
import torch

In [None]:
from functions.functions_model import get_CNN_model, get_NN_model
from functions.functions_features import extract_features, extract_features_CNN
from functions.functions_whisper_model import SpeechClassificationDataset, SpeechClassifier, train, evaluate
from functions.functions_onset_metric import (
    binary_to_intervals, intervals_to_binary, 
    merge_close_intervals, 
    compute_overlap, 
    total_overlap_duration,
    subtract_intervals, 
    total_duration, 
    overlap_calculations, 
    interval_intersection,
    total_interval_duration,
    compute_overlap, 
    inverse_intervals,
    evaluate_intervals,
    remove_isolated_detection,
    fill_short_gaps,
    remove_mean_threshold,
    safe_divide,
    )

# Feature Column Definitions

In [None]:
# Column labels for the feature vector returned by extract_features. The prediction code
# builds a DataFrame from that vector positionally, so this order must match the function's
# output order (TODO confirm against functions_features).
columns_features = [
    'mean', 'variance', 'std_dev', 'max_value', 'min_value', 'rms',
    'skewness', 'kurtosis', 'median', 'range_val', 'iqr',
    'zcr', 'energy', 'rmse', 'entropy',
    'spectral_centroid', 'spectral_bandwidth', 'spectral_contrast',
    'spectral_flatness', 'spectral_rolloff', 'chroma_stft',
]

# 20 MFCC coefficients, each contributing a mean column followed by a std column.
columns_features += [
    mfcc_label
    for coeff_idx in range(1, 21)
    for mfcc_label in (f'mfcc_mean_{coeff_idx}_mean', f'mfcc_{coeff_idx}_std')
]

# Columns removed from the feature vector before scaling / prediction.
column_drop = ['mean', 'variance', 'std_dev', 'skewness', 'kurtosis', 'median', 'range_val', 'iqr']

# Evaluation metrics

In [None]:
# Quick check of merge_close_intervals: neighbouring intervals whose gap is within
# max_gap (0.5 here) are presumably fused into a single interval.
intervals = [
    [0.7, 0.8],
    [1.0, 1.2],
    [5, 5.5],
    [6, 7],
    [10, 12],
]
merged_intervals = merge_close_intervals(intervals, max_gap=0.5)
print(merged_intervals)

In [None]:
# # Example intervals
# predicted_intervals = [[1, 4], [8, 11], [40, 45], [46, 50]]
# ground_truth_intervals = [[2, 5], [9, 12], [39, 44], [50, 54]]

# metrics_overlap = overlap_calculations(predicted_intervals, ground_truth_intervals)
# for key, value in metrics_overlap.items():
#     print(f"{key}: {value}")

In [None]:
# Quick check of inverse_intervals on a timeline of length 10 with two labelled spans
# (presumably returns the gaps not covered by `labels`).
labels = [[1, 3], [5, 7]]
duration = 10
inverse_result = inverse_intervals(labels, duration)
print(inverse_result)

In [None]:
# Quick check of overlap_calculations on two small interval lists.
predicted_intervals = [[1, 4], [7, 10]]
ground_truth_intervals = [[2, 5], [8, 9]]

metrics_overlap = overlap_calculations(predicted_intervals, ground_truth_intervals)
for metric_name, metric_value in metrics_overlap.items():
    print(f"{metric_name}: {metric_value}")

In [None]:
# Quick check of evaluate_intervals: two short predictions inside one long ground-truth
# interval, with a 10% minimum overlap (the config cell describes this threshold as the
# minimum overlap for a prediction to count as a TP).
# The result is stored in `metrics_demo` (not `metrics`) so the imported sklearn `metrics`
# module is not shadowed, and the threshold uses a demo-only name so this cell does not
# leak a `threshold_overlap` global into the evaluation run.
predicted_intervals = [[1, 5], [15, 20]]
ground_truth_intervals = [[0, 20]]
demo_threshold_overlap = 0.1  # 10% overlap required

metrics_demo = evaluate_intervals(
    predicted_intervals,
    ground_truth_intervals,
    overlap_threshold=demo_threshold_overlap,
)
for metric_name, metric_value in metrics_demo.items():
    print(f"{metric_name}: {metric_value}")

# Load Data

In [None]:
# Datasets evaluated in this run; their joined names form part of every model/result path.
list_dataset_name = ['coswara', 'coughvid', 'esc50', 'fsdkaggle', 'virufy']
dataset_str = '_'.join(list_dataset_name)

# ---- Segmentation (defaults) ----
segment_length = 0.3  # window length in seconds; options: 0.1, 0.2, 0.3, 0.5, 0.7, 1
overlap = 0           # overlap between consecutive windows, as a fraction of the window
step_size = segment_length * (1 - overlap)  # hop between window starts, in seconds

# ---- Thresholds ----
threshold_proba = 0.5             # class-1 probability needed for a positive ML prediction
threshold_amplitude_mean = 0.005  # mean |amplitude| at or below which a segment is treated as silent
threshold_max_gap = 0.2           # predicted intervals closer than this are merged
threshold_overlap = 0.1           # minimum overlap fraction for a prediction to count as TP

In [None]:
# Choose which trained onset model to evaluate:
#   'LR', 'DT', 'RF', 'SVM', 'KNN', 'NB', 'NN', 'GB' -> joblib-saved models queried with predict_proba
#   'Keras_NN' -> dense network built by get_NN_model (weights loaded from .h5)
#   'CNN'      -> convolutional network built by get_CNN_model (weights loaded from .h5)
#   'Whisper'  -> Whisper encoder + SpeechClassifier head (PyTorch checkpoint)
# model_name = 'LR'
# model_name = 'RF'
# model_name = 'Keras_NN'
# model_name = 'CNN'
model_name = 'Whisper'

if model_name in ['LR', 'DT', 'RF', 'SVM', 'KNN', 'NB', 'NN', 'GB']:
    #################################################################################
    # ML (joblib-saved models; 'NN' here is a joblib model, not the Keras network)
    #################################################################################
    print(f'ML: {model_name}')
    path_model_save = f'Results_Onset/Model_Onset/{dataset_str}/{model_name}_{segment_length}s/'
    model_filename = f"{path_model_save}model_1.joblib"
    scaler_filename = f"{path_model_save}scaler__1.joblib"
    # NOTE(review): joblib.load unpickles arbitrary objects; only load trusted model files.
    model = joblib.load(model_filename)
    scaler = joblib.load(scaler_filename)

elif model_name == 'Keras_NN':
    ##################################################################################
    # Keras NN -- must match the 'Keras_NN' name the evaluation loop checks for
    ##################################################################################
    path_model_save = f'Results_Onset/Model_Onset/{dataset_str}/Keras_NN_{segment_length}s/'
    # 53 inputs = 61 entries of columns_features minus the 8 listed in column_drop.
    model = get_NN_model(53)
    model.load_weights(f'{path_model_save}model_1.h5')
    scaler_filename = f"{path_model_save}scaler__1.joblib"
    scaler = joblib.load(scaler_filename)

elif model_name == 'CNN':
    #################################################################################
    # CNN
    #################################################################################
    # Number of time frames per segment for each segment length. Every entry follows
    # 1 + floor(segment_length * 22050 / 512), which presumably mirrors librosa's default
    # sample rate and hop length inside extract_features_CNN (TODO confirm). The entry for
    # 1 s was 22, which breaks that pattern; 1 + floor(22050 / 512) = 44.
    dimension_dictionary = {
        0.1: 5,
        0.2: 9,
        0.3: 13,
        0.5: 22,
        0.7: 31,
        1: 44,
    }

    dim_first = 128
    input_shape = (dim_first, dimension_dictionary[segment_length], 1)
    model = get_CNN_model(input_shape)

    path_model_save = f'Results_Onset/Model_CNN_Onset/{dataset_str}/{segment_length}s/'
    model.load_weights(f'{path_model_save}model_CNN_{segment_length}s_1.h5')
    scaler_filename = f"{path_model_save}scaler_pipeline_CNN_{segment_length}s_1.joblib"
    scaler = joblib.load(scaler_filename)

elif model_name == 'Whisper':
    ##################################################################################
    # Whisper
    ##################################################################################
    path_model_save = f'Results_Onset/Model_Whisper_Onset/{dataset_str}/whisper_best_model_{segment_length}s.pt'

    model_checkpoint = "openai/whisper-base"
    processor = WhisperProcessor.from_pretrained(model_checkpoint)
    whisper_model = WhisperModel.from_pretrained(model_checkpoint)
    encoder = whisper_model.encoder  # this is the encoder module
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    num_labels = 2
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    state_dict = torch.load(path_model_save, map_location=device)
    model = SpeechClassifier(num_labels, encoder).to(device)
    model.load_state_dict(state_dict)
    # Inference only: put the module in eval mode (affects dropout / normalisation layers, if any).
    model.eval()

else:
    raise ValueError(f'Unknown model_name: {model_name!r}')


In [None]:
#################################################################################
# Main
#################################################################################
# Evaluate the loaded model clip by clip. For each clip:
#   1. split it into fixed-length segments and classify each segment;
#   2. post-process the binary predictions and convert them to time intervals;
#   3. compute duration-based metrics (SENd / SPEd / PREd / F1d) and the
#      evaluate_intervals metrics;
#   4. plot the waveform, the probabilities and the intervals.

# Maximum number of clips evaluated per dataset (None = all). min() below ensures the cap
# never exceeds the number of available clips.
max_files_per_dataset = 20


def _predict_segment(segment, sr):
    """Classify one audio segment with the model selected in the model-loading cell.

    Args:
        segment: 1-D array of audio samples for one window, at sample rate `sr`.
        sr: sample rate of `segment` in Hz.

    Returns:
        (pred, pred_proba): predicted class index and per-class probability vector
        (column 1 is the curve plotted under "Predict Probability").

    Raises:
        ValueError: if `model_name` is not a supported model family.
    """
    if model_name == 'Whisper':
        # Resample from the loaded rate `sr` (librosa's default is 22050 Hz, not 22500)
        # to the 16 kHz passed to the Whisper processor.
        new_length = int(len(segment) * 16000 / sr)
        segment_16k = resample(segment, new_length).astype(np.float32)
        inputs = processor(segment_16k, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(device)
        with torch.no_grad():
            logits = model(input_features)
            pred_proba = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()[0]
        return np.argmax(pred_proba), pred_proba

    if model_name == 'CNN':
        n_frames = dimension_dictionary[segment_length]
        features = extract_features_CNN(segment, sr, segment_length)
        features = np.array(features[0]).reshape((1, dim_first, n_frames))[..., np.newaxis]
        features = np.nan_to_num(features, nan=0)
        pred_proba = model.predict(features, verbose=0)[0]
        return np.argmax(pred_proba), pred_proba

    if model_name in ['LR', 'DT', 'RF', 'SVM', 'KNN', 'NB', 'NN', 'GB', 'Keras_NN']:
        features = pd.DataFrame([extract_features(segment, sr)], columns=columns_features)
        features = features.drop(columns=[col for col in column_drop if col in features.columns])
        features = np.nan_to_num(np.array(features)[0], nan=0)
        features = scaler.transform(features.reshape(1, -1))

        if model_name == 'Keras_NN':
            # predict() returns one row per input sample; take row 0 so the probabilities
            # stack to (n_segments, n_classes) like the other model families.
            pred_proba = model.predict(features, verbose=0)[0]
            return np.argmax(pred_proba), pred_proba

        proba = model.predict_proba(features)
        # Class 1 is predicted when its probability reaches threshold_proba.
        pred = (proba[:, 1] >= threshold_proba).astype(int)[0]
        return pred, list(proba[0])

    raise ValueError(f'Unsupported model_name: {model_name!r}')


def plot_intervals(ax, intervals, y, color, label=None):
    """Draw each [start, end] interval as a horizontal bar at height `y` on `ax`."""
    for start, end in intervals:
        ax.barh(y=y, width=end - start, left=start, height=0.2, color=color, label=label)


#################################################################################
# Main
#################################################################################
results_all = []
for dataset_name in list_dataset_name:
    print(dataset_name)

    df_all = pd.read_csv(f'Results_Onset/Data_Onset/Annotation/data_summary_{dataset_name}_{segment_length}s_onset_label.csv')
    # Annotation columns are stored as Python-literal strings; parse them back into objects.
    df_all['label_onset'] = df_all['label_onset'].apply(ast.literal_eval)
    df_all['label_event'] = df_all['label_event'].apply(ast.literal_eval)

    # Evaluate label-0 clips only (change to 1 to evaluate the other class).
    df_all = df_all[df_all['label'] == 0].reset_index(drop=True)

    total_len = len(df_all)
    if max_files_per_dataset is not None:
        total_len = min(total_len, max_files_per_dataset)

    for i in tqdm(range(total_len)):

        filepath = df_all['filepath'][i]  # Audio path
        dataset = df_all['dataset'][i]  # Dataset name
        filename = df_all['filename'][i]

        label = df_all['label'][i]
        age = df_all['age'][i]
        gender = df_all['gender'][i]
        status = df_all['status'][i]

        # Load audio with librosa's defaults (resampled to 22050 Hz, mono).
        (y, sr) = librosa.load(filepath)
        duration = librosa.get_duration(y=y, sr=sr)

        # Nominal segment start times (printed, and used to size the fallback labels).
        time_intervals = np.arange(0, duration - segment_length + step_size, step_size)
        print('Duration:', duration)
        print('Time:', time_intervals)

        segment_samples = int(segment_length * sr)
        # Hop between segment starts in samples. `overlap` is a fraction of the window
        # (see step_size), so it is converted using the window length, not `sr`.
        step = segment_samples - int(round(overlap * segment_samples))

        try:
            label_onset = list(df_all['label_onset'][i])
        except (KeyError, TypeError, ValueError):
            # Missing or non-iterable annotation: fall back to an all-zero label sequence.
            label_onset = [0] * len(time_intervals)

        label_pred = []
        label_pred_proba = []
        list_threshold_amplitude_mean = []

        #################################################################################
        # Looping through segments (one prediction per annotated segment)
        #################################################################################
        if len(label_onset) != 0:
            for j in range(len(label_onset)):
                start_sample = j * step
                segment = y[start_sample:start_sample + segment_samples]

                if len(segment) < segment_samples:
                    # Zero-pad the final, shorter segment to the full window length.
                    padding = np.zeros(segment_samples - len(segment))
                    segment = np.concatenate((segment, padding))

                # 1 if the segment's mean |amplitude| exceeds the threshold, else 0;
                # consumed by remove_mean_threshold below.
                mean_amplitude = np.mean(np.abs(segment))
                list_threshold_amplitude_mean.append(0 if mean_amplitude <= threshold_amplitude_mean else 1)

                pred, pred_proba = _predict_segment(segment, sr)
                label_pred.append(pred)
                label_pred_proba.append(pred_proba)

            label_pred_proba = np.array(label_pred_proba)
            print('True:', label_onset)
            print('Pred:', label_pred)

            # Preprocessing
            label_pred = remove_mean_threshold(label_pred, list_threshold_amplitude_mean)
            label_pred = fill_short_gaps(label_pred, threshold=1)
            label_pred = remove_isolated_detection(label_pred, max_length_sequence=1)  # Remove isolated detection

            # Convert into intervals.
            # NOTE(review): time_step=segment_length assumes overlap == 0.
            label_onset_interval = binary_to_intervals(label_onset, time_step=segment_length)
            label_pred_interval = binary_to_intervals(label_pred, time_step=segment_length)

            label_pred_interval = merge_close_intervals(label_pred_interval, threshold_max_gap)  # Merge close intervals

            # Convert back
            label_pred = intervals_to_binary(label_pred_interval, len(label_pred))
            label_onset = intervals_to_binary(label_onset_interval, len(label_onset))

            # Get non-cough intervals
            label_onset_inv_interval = inverse_intervals(label_onset_interval, duration)
            label_pred_inv_interval = inverse_intervals(label_pred_interval, duration)

            # Get cough and non-cough intersects
            label_cough_interval_intersect = interval_intersection(label_onset_interval, label_pred_interval)
            label_non_cough_interval_intersect = interval_intersection(label_onset_inv_interval, label_pred_inv_interval)

            # Get intersection duration
            total_cough_intersect_duration = round(total_interval_duration(label_cough_interval_intersect), 1)
            total_non_cough_intersect_duration = round(total_interval_duration(label_non_cough_interval_intersect), 1)

            total_non_cough_duration = round(total_interval_duration(label_onset_inv_interval), 1)
            total_cough_duration = round(total_interval_duration(label_onset_interval), 1)
            total_pred_duration = round(total_interval_duration(label_pred_interval), 1)

            # Duration-based sensitivity, specificity, precision and F1.
            SENd = round(safe_divide(total_cough_intersect_duration, total_cough_duration), 3)
            SPEd = round(safe_divide(total_non_cough_intersect_duration, total_non_cough_duration), 3)
            PREd = round(safe_divide(total_cough_intersect_duration, total_pred_duration), 3)
            F1d = round(2 * safe_divide((PREd * SENd), (PREd + SENd)), 3)

            audio_information = {
                'filepath': filepath,
                'dataset': dataset,
                'filename': filename,
                'label': label,
                'age': age,
                'gender': gender,
                'status': status,
                'segment_length': segment_length,
                'overlap': overlap,
                'label_onset': label_onset,
                'label_pred': label_pred,
                'len': len(label_onset),
                'duration': len(label_onset)*segment_length,
                'threshold_proba': threshold_proba,
                'threshold_amplitude_mean': threshold_amplitude_mean,
                'threshold_overlap': threshold_overlap,
                'threshold_max_gap': threshold_max_gap,
                'label_onset_interval': label_onset_interval,
                'label_pred_interval': label_pred_interval,
                'label_onset_inv_interval': label_onset_inv_interval,
                'label_pred_inv_interval': label_pred_inv_interval,
                'label_cough_interval_intersect': label_cough_interval_intersect,
                'label_non_cough_interval_intersect': label_non_cough_interval_intersect,
                'total_cough_intersect_duration': total_cough_intersect_duration,
                'total_non_cough_intersect_duration': total_non_cough_intersect_duration,
                'total_cough_duration': total_cough_duration,
                'total_non_cough_duration': total_non_cough_duration,
                'total_pred_duration': total_pred_duration,
                'SENd': SENd,
                'SPEd': SPEd,
                'PREd': PREd,
                'F1d': F1d,
            }

            # Named metrics_interval (not `metrics`) so sklearn's `metrics` module is not shadowed.
            metrics_interval = evaluate_intervals(
                label_pred_interval,
                label_onset_interval,
                overlap_threshold=threshold_overlap)

            metrices_combined = {**audio_information, **metrics_interval}

            list_print = [
                'label_onset_interval', 'label_pred_interval',
                'total_cough_intersect_duration',
                'total_non_cough_intersect_duration',
                'total_cough_duration', 'total_non_cough_duration', 'total_pred_duration',
                'SENd', 'SPEd', 'PREd', 'F1d',
                ]
            for key in list_print:
                print(f'{key}: {metrices_combined[key]}')

            results_all.append(metrices_combined)

            # Create a figure with subplots
            fig, axs = plt.subplots(3, 1, figsize=(6, 5), sharex=True)

            # Plot waveform
            librosa.display.waveshow(y, sr=sr, ax=axs[0])
            axs[0].set_title('Audio Waveform')
            axs[0].set_xlabel('Time (s)')
            axs[0].set_ylabel('Amplitude')
            axs[0].grid(True)
            axs[0].minorticks_on()  # Enable minor ticks for finer control

            # Plot class-1 probability at each evaluated segment's start time. Deriving the
            # x values from the number of predictions keeps both arrays the same length.
            segment_times = np.arange(len(label_pred_proba)) * step / sr
            axs[1].plot(segment_times, label_pred_proba[:, 1], color='green', marker='o')
            axs[1].axhline(y=threshold_proba, linestyle=':', color='black')  # dotted line at the probability threshold
            axs[1].set_ylim(0, 1.1)
            axs[1].set_title('Predict Probability')
            axs[1].set_ylabel('Values')

            # Plot onset (red), pred (green), and intersection (black)
            plot_intervals(axs[2], label_onset_interval, y=0.8, color='red', label='True')
            plot_intervals(axs[2], label_pred_interval, y=0.5, color='green', label='Predict')
            plot_intervals(axs[2], label_cough_interval_intersect, y=0.2, color='black', label='Intersect')

            axs[2].set_title('Prediction vs True Labels')
            axs[2].set_ylim(0, 1)
            axs[2].set_yticks([0.2, 0.5, 0.8])
            axs[2].set_yticklabels(['Intersect', 'Predict', 'True'])
            axs[2].set_xlabel('Time (s)')

            plt.tight_layout()
            plt.show()
            plt.close(fig)  # free the figure; many clips are plotted in one cell

            # Play audio
            display(Audio(data=y, rate=sr))

# Build directly from the list of records. Column order follows the dict keys, and this
# still works when no clip was evaluated (metrices_combined would then be undefined).
results_all = pd.DataFrame(results_all)
results_all.to_csv(f'Results_Onset/results_onset_metrics_{segment_length}s.csv', index=False)

In [None]:
# Summary of the duration-based metrics across all evaluated clips.
results_all[['SENd', 'SPEd', 'PREd', 'F1d']].describe()

In [None]:
# Save an extra copy of the results under a fixed file name (overwritten on every run).
# pd.DataFrame(...) accepts the list of records or an existing DataFrame, so this no longer
# depends on the loop variable `metrices_combined`, which is undefined if no clip was evaluated.
results_all = pd.DataFrame(results_all)
results_all.to_csv('Results_Onset/results_onset_metrics.csv', index=False)

In [None]:
# Display the results table (one row per evaluated clip).
results_all