In [1]:
import html
import math

import hume_wsds as wsds
import pandas as pd
from IPython.display import display, Audio, Markdown, HTML
from tqdm import tqdm

In [65]:
class WSGetDuration:
    """Computed-column loader that reports how long each audio sample lasts.

    The value is the difference between the ``tend`` and ``tstart`` attributes
    of whatever the parent dataset stores in its ``"audio"`` column; the
    waveform itself is never decoded here. The unit is whatever ``tstart`` /
    ``tend`` use (presumably seconds -- TODO confirm).
    """

    def __init__(self, sample_rate=16000):
        # The dataset binding is filled in afterwards by ``from_link``.
        self.parent_dataset = None
        self.shard_name = None
        # Kept on the instance, although ``get_sample`` does not read it.
        self.sample_rate = sample_rate

    @classmethod
    def from_link(cls, link, source_dataset, parent_dataset, shard_name):
        """Create a loader bound to ``shard_name`` inside ``parent_dataset``.

        ``link`` and ``source_dataset`` are accepted for signature
        compatibility but are not used.
        """
        loader = cls()
        loader.parent_dataset, loader.shard_name = parent_dataset, shard_name
        return loader

    def get_sample(self, _column, offset):
        """Return ``tend - tstart`` for the audio entry at ``offset``.

        ``_column`` is ignored.
        """
        segment = self.parent_dataset.get_sample(self.shard_name, "audio", offset)
        end, start = segment.tend, segment.tstart
        return end - start
        
class WSFilterBoolean:
    """Computed-column loader that flags samples passing simple quality checks.

    A sample passes when all three conditions hold:

    * its ``"snr"`` value is strictly greater than ``snr_thresh``;
    * its characters-per-second rate (transcript length divided by duration)
      is strictly less than ``cps_thresh``;
    * its duration is strictly greater than ``duration_thresh``.

    Args:
        snr_thresh: Lower bound (exclusive) on the ``"snr"`` column.
        cps_thresh: Upper bound (exclusive) on characters per second.
        duration_thresh: Lower bound (exclusive) on the duration.
        text_key: Name of the column holding the transcript.
        sample_rate: Rate used to decode audio when no ``"duration"`` column
            is available and the duration has to be derived from the waveform.
    """

    def __init__(self, snr_thresh=10.0, cps_thresh=20.0, duration_thresh=1, text_key='transcription_wslang_raw.txt', sample_rate=16000):
        self.snr_thresh = snr_thresh
        self.cps_thresh = cps_thresh
        self.duration_thresh = duration_thresh
        self.sample_rate = sample_rate
        self.text_key = text_key
        # The dataset binding is filled in afterwards by ``from_link``.
        self.parent_dataset = None
        self.shard_name = None

    def _duration_seconds(self, offset):
        """Return the duration of the sample at ``offset``.

        The ``"duration"`` column is preferred. If looking it up raises
        ``KeyError``, the audio is decoded at ``self.sample_rate`` and the
        duration is computed as the length of its last axis divided by that
        rate.
        """
        try:
            duration_sec = float(self.parent_dataset.get_sample(self.shard_name, "duration", offset))
        except KeyError:
            waveform = (
                self.parent_dataset.get_sample(self.shard_name, "audio", offset)
                .load(self.sample_rate)
                .numpy()
            )
            duration_sec = waveform.shape[-1] / self.sample_rate
        return duration_sec

    def get_sample(self, _column, offset):
        """Return ``True`` when the sample at ``offset`` passes every check.

        ``_column`` is ignored.
        """
        snr = float(self.parent_dataset.get_sample(self.shard_name, "snr", offset))
        txt = self.parent_dataset.get_sample(self.shard_name, self.text_key, offset)
        duration_sec = self._duration_seconds(offset)

        # A non-positive duration yields an infinite rate, so the sample is
        # rejected by the cps check instead of dividing by zero.
        cps = len(txt) / duration_sec if duration_sec > 0 else float("inf")
        return (snr > self.snr_thresh) and (cps < self.cps_thresh) and (duration_sec > self.duration_thresh)

    @classmethod
    def from_link(cls, link, source_dataset, parent_dataset, shard_name):
        """Create a loader from ``link["config"]`` bound to ``shard_name``.

        Recognised config keys are ``snr_thresh``, ``cps_thresh``,
        ``duration_thresh``, ``text_key`` and ``sample_rate``; any missing key
        falls back to its default. A missing or ``None`` config is treated as
        empty. ``source_dataset`` is accepted but not used.
        """
        config = link.get("config") or {}
        inst = cls(
            snr_thresh=config.get("snr_thresh", 10.0),
            cps_thresh=config.get("cps_thresh", 20.0),
            duration_thresh=config.get("duration_thresh", 1.0),
            text_key=config.get("text_key", "transcription_wslang_raw.txt"),
            # Previously dropped, which pinned the fallback decode to 16000.
            sample_rate=config.get("sample_rate", 16000),
        )
        inst.parent_dataset = parent_dataset
        inst.shard_name = shard_name
        return inst
  
# Transcript column shared by the filter config below and the export loop.
text_key = 'transcription_wslang_raw.txt'
ds = wsds.WSDataset("/mnt/weka/data-wsds/fb_ears/v4-vad_ws")

# Per-sample duration column, produced by WSGetDuration.
# NOTE(review): WSFilterBoolean looks up a "duration" column on its parent
# dataset -- presumably this registration is what provides it; confirm.
ds.add_computed("duration", dataset_dir="../source", loader=WSGetDuration)

# Thresholds handed to WSFilterBoolean.from_link via the link's "config" entry.
filter_config = {
    "snr_thresh": 10.0,
    "cps_thresh": 20.0,
    "duration_thresh": 0.5,
    "text_key": text_key,
}

# The loader is passed as the class object itself, not as a string name.
ds.add_computed(
    "passed-filter",
    dataset_dir="../source",
    loader=WSFilterBoolean,
    config=filter_config,
)

In [66]:
rows = []
MAX_SAMPLES = 500      # upper bound on how many dataset samples are inspected
SAMPLE_RATE = 16000    # rate used to decode audio for the embedded player
PREVIEW_SECONDS = 2    # only this much trailing audio is embedded per row

preview_len = PREVIEW_SECONDS * SAMPLE_RATE

for i, sample in tqdm(enumerate(ds)):
    # Enforce the budget before anything else. When this check sat at the
    # bottom of the loop, samples rejected by the filter skipped it through
    # `continue`, and index MAX_SAMPLES itself was still processed.
    if i >= MAX_SAMPLES:
        break

    if not sample["passed-filter"]:
        continue

    snr = float(sample["snr"])
    txt = sample[text_key]
    duration_sec = float(sample["duration"])
    cps = len(txt) / duration_sec if duration_sec > 0 else 0

    # Indexing assumes a 2-D array with time on the last axis -- TODO confirm.
    audio = sample["audio"].load(SAMPLE_RATE).numpy()
    # Keep only the tail of the clip so the embedded player stays small.
    audio_html = Audio(audio[:, -preview_len:], rate=SAMPLE_RATE)._repr_html_().strip()

    rows.append({
        "Sample": sample['__key__'],
        "Audio": audio_html,
        "Transcript": txt,
        "SNR": round(snr, 2),
        "Duration": round(duration_sec, 2),
        "CPS": round(cps, 2),
    })

# One record per sample that passed the filter within the inspected range.
df = pd.DataFrame(rows)

500it [00:05, 91.80it/s] 


In [67]:
# Inline stylesheet placed at the top of the exported HTML report.
style = """
<style>
table.custom-table {
    border-collapse: collapse;
    width: 100%;
    font-family: Arial, sans-serif;
    font-size: 14px;
}
.custom-table th, .custom-table td {
    border: 1px solid #ddd;
    padding: 8px;
    vertical-align: top;
}
.custom-table tr:nth-child(even) {
    background-color: #f9f9f9;
}
.custom-table tr:hover {
    background-color: #f1f1f1;
}
.custom-table th {
    padding-top: 12px;
    padding-bottom: 12px;
    text-align: left;
    background-color: #004c7f;
    color: white;
}
audio {
    width: 120px;
}
</style>
"""

columns = ["Sample", "Audio", "Transcript", "SNR", "Duration", "CPS"]
# Columns whose values are already HTML markup (the audio player) and must be
# inserted verbatim; every other column is escaped.
RAW_HTML_COLUMNS = {"Audio"}


def _render_cell(column, value):
    """Return one ``<td>`` element for ``value`` in ``column``.

    Values of columns listed in ``RAW_HTML_COLUMNS`` are inserted unchanged.
    Everything else is converted with ``str`` and HTML-escaped, so characters
    such as ``<`` or ``&`` in a transcript cannot break the table markup.
    """
    text = value if column in RAW_HTML_COLUMNS else html.escape(str(value))
    return f"<td>{text}</td>"


records = df.to_dict(orient="records")

header_html = "".join(f"<th>{col}</th>" for col in columns)
body_html = "".join(
    "<tr>" + "".join(_render_cell(col, row.get(col, "")) for col in columns) + "</tr>"
    for row in records
)

table_html = (
    style
    + "<table class='custom-table'>"
    + "<thead><tr>" + header_html + "</tr></thead><tbody>"
    + body_html
    + "</tbody></table>"
)

# NOTE(review): the file name says "voxceleb" while the dataset opened in this
# notebook lives under ".../fb_ears/..." -- confirm the intended name.
# The file is written as UTF-8 and declares that charset, so non-ASCII
# transcripts render correctly when the file is opened directly in a browser.
with open("voxceleb_vad_samples.html", "w", encoding="utf-8") as f:
    f.write("<meta charset='utf-8'>\n" + table_html)
