In [None]:
import json
import os
import pickle
import re
import sys
sys.path.append('../')
from glob import glob

import editdistance
import numpy as np
import pandas as pd
from pyannote.audio import Model
from pyannote.audio.pipelines import VoiceActivityDetection
from pyannote.core import notebook, Segment
from pydub import AudioSegment
from tqdm import tqdm
from whisper_model import WhisperASR
import matplotlib.pyplot as plt

import scipy.io.wavfile as wav
import scipy.signal as signal

import tempfile




import torch
torch.set_num_threads(1)

from IPython.display import Audio
from pprint import pprint

import os
import sys
sys.path.append('../')

import psycopg2
import wave
from pydub import AudioSegment
from tqdm import tqdm
from dotenv import load_dotenv

# Project-local audio helpers; CustomVAD is instantiated and used later in this notebook.
from src.utils.audio import CustomVAD, trim_audio

# Load environment variables from the env file one directory up.
load_dotenv("../vars.env")

# Database credentials, read from the environment. os.getenv returns None for
# any variable that is not set.
db_host = os.getenv("POSTGRES_HOST")
db_name = os.getenv("POSTGRES_DB")
db_user = os.getenv("POSTGRES_USER")
db_password = os.getenv("POSTGRES_PWD")

# Open one module-level connection and cursor. The helper functions defined in
# the next cell read `conn` and `cur` as globals.
conn = psycopg2.connect(host=db_host, database=db_name, user=db_user, password=db_password)
cur = conn.cursor()

# Retrieve the id and name of every row in the dataset table.
cur.execute("SELECT id, name FROM dataset;")
datasets = cur.fetchall()



In [None]:
def get_samples_filepath(filename):
    """Look up the local path of a delivery-selected sample by its filename.

    Runs the query on the module-level ``cur`` cursor and ends the transaction
    on the module-level ``conn`` connection.

    Args:
        filename: Value compared against ``sample.filename``.

    Returns:
        The ``local_path`` column of the first matching row, or ``None`` when
        no row matches or a database error occurs.
    """
    try:
        cur.execute(
            "SELECT local_path FROM sample WHERE filename = %s AND is_selected_for_delivery=True",
            (filename,),
        )
        row = cur.fetchone()
        # fetchone() yields None when nothing matched; indexing it would raise
        # a TypeError that the handler below does not catch.
        return row[0] if row is not None else None
    except psycopg2.Error as e:
        # Roll back so the shared connection stays usable for later queries.
        conn.rollback()
        print(f"An error occurred: {e}")
        return None
    finally:
        # Runs on every path: close out the current transaction on the shared
        # connection before the next call.
        conn.commit()


def get_100_samples(dataset_id, n_samples=100):
    """Draw a random subset of delivery-selected sample paths for a dataset.

    Runs the query on the module-level ``cur`` cursor and ends the transaction
    on the module-level ``conn`` connection.

    Args:
        dataset_id: Value compared against ``sample.dataset_id``.
        n_samples: Maximum number of paths to draw (default 100).

    Returns:
        A NumPy array of distinct ``local_path`` values. It holds
        ``min(n_samples, number of matching rows)`` entries and is empty when
        no row matches. ``None`` is returned when a database error occurs.
    """
    try:
        # Retrieve all samples that are selected for delivery from the dataset.
        cur.execute(
            """
            SELECT local_path
            FROM sample
            WHERE is_selected_for_delivery = TRUE AND
                    dataset_id = %s;
        """,
            (dataset_id,),
        )
        paths = [row[0] for row in cur.fetchall()]
        if not paths:
            # np.random.choice raises ValueError on an empty population.
            return np.array([], dtype=str)
        # replace=False keeps each path at most once (the default samples with
        # replacement); the size is capped because drawing more items than
        # exist without replacement is an error.
        return np.random.choice(paths, size=min(n_samples, len(paths)), replace=False)
    except psycopg2.Error as e:
        # Roll back so the shared connection stays usable for later queries.
        conn.rollback()
        print(f"An error occurred: {e}")
        return None
    finally:
        # Runs on every path: close out the current transaction on the shared
        # connection before the next call.
        conn.commit()





In [None]:
# files = [
#     "ES00000541.wav",
#     "ES00002307.wav",
#     "ES00005896.wav",
#     "ES00015863.wav",
#     "ES00017186.wav",
#     "ES00032124.wav",
#     "DE00008847.wav",
#     "DE00051117.wav",

#     ]
# FILES = [get_samples_filepath(f) for f in files]


In [None]:
# Dataset whose delivery-selected samples are processed in the cells below.
dataset_id = 52
# Random draw of sample paths for that dataset; get_100_samples returns None
# if the database query fails.
FILES = get_100_samples(dataset_id)

In [None]:
# Show the selected paths as the cell output.
FILES

In [None]:
# NOTE(review): duplicate of the CustomVAD import in the first cell.
from src.utils.audio import CustomVAD

# Build the project's CustomVAD from the two model identifiers given here.
# NOTE(review): both values look like hub repository names rather than local
# filesystem paths — confirm how CustomVAD resolves them.
my_custom_vad = CustomVAD(
    pyannote_model_path="pyannote/segmentation", silero_model_path="snakers4/silero-vad"
)


In [None]:
# Plot every selected file's waveform with the three segments reported by
# CustomVAD.process_file overlaid as shaded spans.
# NOTE(review): the meaning and units of the threshold are defined inside
# CustomVAD and are not visible in this notebook.
my_custom_vad.set_energy_threshold(3_000_000)

# (response key, span colour, legend label) for each segment to highlight.
SEGMENT_STYLES = (
    ("pyannote_segment", "green", "PYANNOTE"),
    ("silero_segment", "red", "SILERO-VAD"),
    ("custom_segment", "blue", "MY_CUSTOM_VAD"),
)

for file in FILES:
    response = my_custom_vad.process_file(file)

    # Build a time axis matching the resampled waveform.
    waveform = response["resampled_waveform"]
    num_samples = len(waveform)
    duration = num_samples / my_custom_vad.SAMPLING_RATE
    time_vector = np.linspace(0, duration, num=num_samples)

    fig, ax = plt.subplots(figsize=(20, 5), dpi=50)
    ax.plot(time_vector, waveform)
    # Scale the y-axis to the waveform amplitude and the x-axis to its duration.
    ax.set_ylim([waveform.min(), waveform.max()])
    ax.set_xlim([0, duration])

    ax.set_xlabel("Time (seconds)")
    ax.set_ylabel("Amplitude")
    ax.set_title(os.path.basename(file))

    # Shade each segment; elements 0 and 1 are used as the span's bounds on the
    # time axis (presumably start/end in seconds — confirm against CustomVAD).
    for segment_key, span_color, span_label in SEGMENT_STYLES:
        segment = response[segment_key]
        ax.axvspan(segment[0], segment[1], color=span_color, alpha=0.3, label=span_label)

    ax.legend()

    plt.show()
    # One figure is created per file; release it explicitly so a run over all
    # of FILES does not accumulate open figures.
    plt.close(fig)