In [None]:
import librosa
import matplotlib.pyplot as plt
from utils.counting import *
import numpy as np
import os
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go


## Annotated Data against Onset Detection

In [None]:
# --- Paths ---------------------------------------------------------------
# Directory holding the per-segment .wav files and directory holding the
# annotation CSV. Both are placeholders to be adapted to the local setup.
AUDIO_DIR = "path/to/detected/wav/trill/segments/"
DATA_DIR = "path/to/annotation/file/"

# --- Annotations ---------------------------------------------------------
trill_rate_annotation_file = "annotations_trills_with_pred.csv"
df_annotation_rhythm = pd.read_csv(os.path.join(DATA_DIR, trill_rate_annotation_file))

# Keep only rows with at least one annotated trill. The explicit .copy() makes
# this an independent frame: later cells add columns to it, and writing into a
# boolean-mask slice would raise SettingWithCopyWarning.
df_annotation_trills = df_annotation_rhythm[df_annotation_rhythm["num_trills"] > 0].copy()

# --- STFT / display settings ----------------------------------------------
hop_length = 64       # STFT hop size, in samples
n_fft = 256           # STFT window size, in samples
display_spec = False  # set to True to plot diagnostics for under-estimated segments

In [15]:
def _plot_trill_diagnostics(y, sr, env, fs_env, ac, lags, peak_lags,
                            filename, trill_rate_true, rate_robust,
                            n_fft=256, hop_length=64):
    """Plot the spectrogram, envelope and autocorrelation of one segment.

    Parameters
    ----------
    y, sr : audio samples and their sample rate, as returned by ``librosa.load``.
    env, fs_env : envelope signal and the sampling rate used for its time axis.
    ac, lags, peak_lags : autocorrelation values, their lag axis, and the lags
        of the detected peaks (as returned by ``trill_rate_robust_fixed``).
    filename : segment file name, used in the figure title.
    trill_rate_true, rate_robust : annotated and detected trill rates.
    n_fft, hop_length : STFT parameters used for the spectrogram.
    """
    S = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length))
    S_db = librosa.amplitude_to_db(S, ref=np.max)

    times = librosa.frames_to_time(np.arange(S.shape[1]), sr=sr, hop_length=hop_length)
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    t_env = np.arange(len(env)) / fs_env

    fig, axes = plt.subplots(3, 1, figsize=(10, 8))

    # --- Spectrogram ---
    axes[0].imshow(
        S_db,
        origin="lower",
        aspect="auto",
        extent=[times[0], times[-1], freqs[0], freqs[-1]],
        cmap="magma"
    )
    axes[0].set_ylabel("Frequency (Hz)")
    axes[0].set_title(
        f"{filename} | true={trill_rate_true:.2f} Hz | detected={rate_robust:.2f} Hz"
    )

    # --- Envelope ---
    axes[1].plot(t_env, env, color="black")
    axes[1].set_ylabel("Envelope")
    axes[1].set_xlabel("Time (s)")
    axes[1].set_title("Spectral envelope")

    # --- Autocorrelation ---
    axes[2].plot(lags, ac, color="black", label="Autocorrelation")

    if len(peak_lags) > 0:
        # searchsorted returns len(lags) for a value beyond the last lag;
        # clip so the lookup into `ac` can never run past the end.
        peak_idx = np.clip(np.searchsorted(lags, peak_lags), 0, len(ac) - 1)
        axes[2].scatter(
            peak_lags,
            ac[peak_idx],
            color="red",
            zorder=3,
            label="Detected peaks"
        )

    if trill_rate_true > 0:
        axes[2].axvline(
            1 / trill_rate_true,
            color="green",
            linestyle="--",
            label="True period"
        )

    # A zero or NaN detected rate has no finite period to draw.
    if np.isfinite(rate_robust) and rate_robust > 0:
        axes[2].axvline(
            1 / rate_robust,
            color="red",
            linestyle="--",
            label="Detected period"
        )

    axes[2].set_xlim(0, 0.3)
    axes[2].set_xlabel("Lag (s)")
    axes[2].set_ylabel("Autocorrelation")
    axes[2].legend()

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across loop iterations


# Estimate the trill rate of every annotated segment and store it, together
# with the annotated ("true") rate, in df_annotation_trills.
for index, row in df_annotation_trills.iterrows():
    filename = row["file_name_radical"].split('.')[0]+f"_seg{row['segment_id']}"+".wav"
    audio_path = os.path.join(AUDIO_DIR, filename)
    if not os.path.exists(audio_path):
        # Stay silent for rows without a predicted trill start (presumably no
        # segment was ever extracted for them — TODO confirm); report the rest.
        if not pd.isna(row["trill_t_start_pred"]):
            print(f"Missing file: {audio_path}")
        continue

    y, sr = librosa.load(audio_path, sr=None, mono=True)

    rate_fft, env, freqs_env, fft_env = trill_rate_detection_am2(y, sample_rate=sr)

    # NOTE(review): assumes trill_rate_detection_am2 computes the envelope with
    # the same hop size as this notebook's `hop_length` — confirm.
    fs_env = sr / hop_length

    rate_robust, ac, lags, peak_lags = trill_rate_robust_fixed(env, fs_env, rate_fft, debug=False)

    # Annotated rate = number of trills / duration; defined as 0 for zero duration.
    trill_duration = row["trill_duration"]
    trill_rate_true = 0 if trill_duration == 0 else row["num_trills"] / trill_duration

    df_annotation_trills.loc[index, "trill_rate_detected_onset"] = rate_robust
    df_annotation_trills.loc[index, "trill_rate_true"] = trill_rate_true

    # Only inspect segments whose detected rate is below 60% of the true rate.
    if display_spec and rate_robust < trill_rate_true * 0.6:
        _plot_trill_diagnostics(
            y, sr, env, fs_env, ac, lags, peak_lags,
            filename, trill_rate_true, rate_robust,
            n_fft=n_fft, hop_length=hop_length,
        )

#### Determine faulty spectrograms

In [None]:
# Signed error between the annotated and the detected trill rate (true - detected).
df_annotation_trills["error_onset"] = df_annotation_trills["trill_rate_true"] - df_annotation_trills["trill_rate_detected_onset"]

# A segment is "faulty" when the absolute error exceeds 80% of its true rate.
df_faulty = df_annotation_trills[df_annotation_trills["error_onset"].abs() > df_annotation_trills["trill_rate_true"]*0.8]

print(f"Number of segments with >80% error: {len(df_faulty)}")
print(f"Proportion of segments with >80% error: {len(df_faulty)/len(df_annotation_trills):.2%}")

# Show the spectrogram of every faulty segment for visual inspection.
for _, row in df_faulty.iterrows():
    filename = row["file_name_radical"].split('.')[0]+f"_seg{row['segment_id']}"+".wav"
    audio_path = os.path.join(AUDIO_DIR, filename)
    if not os.path.exists(audio_path):
        print(f"Missing file: {audio_path}")
        continue

    y, sr = librosa.load(audio_path, sr=None, mono=True)

    # Use the notebook-wide STFT settings so this view matches the other cells,
    # and pass n_fft to fft_frequencies so the frequency axis matches the STFT.
    S_db = librosa.amplitude_to_db(
        np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length)), ref=np.max
    )
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    times = librosa.frames_to_time(np.arange(S_db.shape[1]), sr=sr, hop_length=hop_length)

    fig, ax = plt.subplots(figsize=(10, 4))
    image = ax.imshow(
        S_db,
        origin="lower",
        aspect="auto",
        extent=[0, times[-1], freqs[0], freqs[-1]],
        cmap="magma",
    )
    fig.colorbar(image, ax=ax, label="Amplitude (dB)")
    ax.set_title(f"File {filename} - Trill Detection Error: {row['error_onset']:.2f} trills/sec, True: {row['trill_rate_true']:.2f}, Detected: {row['trill_rate_detected_onset']:.2f}")
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Frequency (Hz)")

    plt.show()
    plt.close(fig)  # one figure per segment: release it once displayed

### Plots

In [None]:
# Rows usable for the scatter plot: both rates present and a true rate of at
# least 2 Hz. (An extra filter on cluster label / upper rate bound was tried
# here previously and is currently disabled.)
rate_columns = ["trill_rate_true", "trill_rate_detected_onset"]
df_with_rates = df_annotation_trills.dropna(subset=rate_columns)
has_min_rate = df_with_rates["trill_rate_true"] >= 2
df_plot = df_with_rates[has_min_rate]

def linear_regression(x, y):
    """Fit ``y ≈ slope * x + intercept`` by least squares and compute R².

    Parameters
    ----------
    x, y : array-like of numbers, same length.

    Returns
    -------
    slope, intercept, r2 : floats
        ``r2`` is ``1 - SS_res / SS_tot``. When ``y`` has zero variance
        (``SS_tot == 0``) R² is undefined and ``nan`` is returned instead of
        dividing by zero.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    slope, intercept = np.polyfit(x, y, 1)
    y_pred = slope * x + intercept

    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    # Constant y: avoid 0/0 (nan + warning) or eps/0 (-inf) from the raw formula.
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else np.nan
    return slope, intercept, r2

def add_regression_line(fig, df, name, color):
    """Add the least-squares line of detected vs. true trill rate to ``fig``.

    The line is fitted on ``df["trill_rate_true"]`` (x) against
    ``df["trill_rate_detected_onset"]`` (y), drawn over the observed x range,
    and labelled with ``name`` and the fit's R² in the legend.
    """
    true_rates = df["trill_rate_true"].values
    detected_rates = df["trill_rate_detected_onset"].values
    slope, intercept, r2 = linear_regression(true_rates, detected_rates)

    # 100 evenly spaced points between the smallest and largest true rate.
    xs = np.linspace(true_rates.min(), true_rates.max(), 100)
    fitted_line = go.Scatter(
        x=xs,
        y=slope * xs + intercept,
        mode="lines",
        line={"color": color, "width": 3},
        name=f"{name} (R²={r2:.2f})",
    )
    fig.add_trace(fitted_line)




x = df_plot["trill_rate_true"].values
y = df_plot["trill_rate_detected_onset"].values

# Linear regression and R², using the helper defined earlier in this cell
# instead of repeating the same computation inline.
slope, intercept, r2 = linear_regression(x, y)

# Scatter of detected vs. true trill rate, one point per segment.
fig = px.scatter(
    df_plot,
    x="trill_rate_true",
    y="trill_rate_detected_onset",
    hover_name="file_name_radical",
    # color="cluster_name",
    # hover_data=["species_x", "family", "num_trills", "annotator"],
    labels={
        "trill_rate_true": "Trill rate (Hz)",
        "trill_rate_detected_onset": "trill_rate_detected_onset (Hz)",
        "cluster_name": "Cluster"
    },
    title="trill_rate_detected_onset vs trill rate",
    opacity=0.6
)

# Regression line, drawn over the observed range of true rates
x_line = np.linspace(x.min(), x.max(), 200)
y_line = slope * x_line + intercept

fig.add_trace(
    go.Scatter(
        x=x_line,
        y=y_line,
        mode="lines",
        name="Linear regression",
        line=dict(color="black", width=2)
    )
)

# Fit equation and R², shown in the top-left corner of the plot area
equation_text = (
    f"y = {slope:.2f} x + {intercept:.1f}<br>"
    f"R² = {r2:.3f}"
)

fig.add_annotation(
    x=0.05,
    y=0.95,
    xref="paper",
    yref="paper",
    text=equation_text,
    showarrow=False,
    align="left",
    font=dict(size=14),
    bgcolor="rgba(255,255,255,0.7)"
)

# Square figure with a 1:1 axis ratio so the y = x diagonal reads as perfect agreement
fig.update_layout(
    width=800,
    height=800
)
fig.update_xaxes(range=[0, 60])
fig.update_yaxes(scaleanchor="x", scaleratio=1)

fig.show()


