In [21]:
import os
import torchaudio
from datetime import timedelta


def split_waveform_by_timestamps(input_file, output_dir, timestamps):
    """Cut an audio file into one .wav file per ``(start, end, speaker)`` entry.

    Args:
        input_file: path to the source audio, read with ``torchaudio.load``.
        output_dir: destination directory; created if it does not exist.
        timestamps: iterable of ``(start, end, speaker)`` tuples with
            ``start``/``end`` in seconds.

    Each clip keeps the full first axis of the loaded tensor (the channel axis
    under torchaudio's default channels-first layout), is saved at the source
    sample rate, and is named ``"{start}_{end}_{speaker}.wav"``.
    """
    waveform, rate = torchaudio.load(input_file)
    os.makedirs(output_dir, exist_ok=True)

    for seg_start, seg_end, label in timestamps:
        # Seconds -> sample indices; int() truncates toward zero.
        first = int(seg_start * rate)
        last = int(seg_end * rate)
        clip = waveform[:, first:last]
        target = os.path.join(output_dir, f"{seg_start}_{seg_end}_{label}.wav")
        torchaudio.save(target, clip, rate)


def aggregate_timestamps(timestamps):
    """Merge consecutive ``(start, end, speaker)`` entries that share a speaker.

    Only adjacent entries are compared, so the input is presumably ordered by
    time (TODO confirm with callers).

    Args:
        timestamps: sequence of ``(start, end, speaker)`` tuples.

    Returns:
        List of ``(start, end, speaker)`` tuples, where each run of consecutive
        entries with the same speaker is collapsed into one entry. That entry
        spans from the run's first ``start`` to its last ``end``. An empty
        input yields ``[]``.
    """
    aggregated = []
    if not timestamps:
        return aggregated

    previous = timestamps[0]
    for timestamp in timestamps:
        start, end, speaker = timestamp
        prevstart, _prevend, prevspeaker = previous
        if speaker == prevspeaker:
            # Extend the current run to cover this entry.
            previous = (prevstart, end, speaker)
        else:
            aggregated.append(previous)
            previous = timestamp
    # The last run never reaches the ``else`` flush above, so emit it here;
    # otherwise the final speaker turn would be silently dropped.
    aggregated.append(previous)
    return aggregated


def timeToSeconds(time):
    """Convert an SRT-style ``"HH:MM:SS,mmm"`` timestamp to seconds (float).

    Only the first ``,``-separated field pair and the first three
    ``:``-separated clock fields are read; anything beyond them is ignored.

    Raises:
        IndexError: if the ``,`` or ``:`` separators are missing.
        ValueError: if a field is not an integer.
    """
    parts = time.split(",")
    clock, millis = parts[0], parts[1]
    fields = clock.split(":")
    hours, minutes, secs = fields[0], fields[1], fields[2]
    delta = timedelta(
        hours=int(hours),
        minutes=int(minutes),
        seconds=int(secs),
        milliseconds=int(millis),
    )
    return delta.total_seconds()

In [22]:
ROOT = os.getcwd()

# NOTE(review): the .wav is resolved relative to the current working directory,
# while the .srt is looked up under ROOT/sample — confirm both files live where
# these paths expect.
sample_wav = "001_-_Scrambling_Eggs.wav"

sample_src = os.path.join(ROOT, "sample", "001_-_Scrambling_Eggs.srt")

# "utf-8-sig" drops a leading BOM if present. Collapsing "\n\n" removes the
# single blank lines between cues, leaving a flat list of index / timing /
# text lines. The `with` block closes the file handle instead of leaking it.
with open(sample_src, encoding="utf-8-sig") as srt_file:
    subtitles = srt_file.read().replace("\n\n", "\n").splitlines()

subtitles[0:3]

['1', '00:00:06,540 --> 00:00:06,067', 'Speaker 1: Hello.']

In [24]:
# Collect (start, end) strings from SRT timing lines and (speaker, text) pairs
# from dialogue lines of the form "Speaker N: text".
timestamps = []
speech = []
arrow = " --> "
for line in subtitles:
    if arrow in line:
        start, end = line.split(arrow, 1)
        timestamps.append((start, end))
    if ": " in line:
        # maxsplit=1 keeps any later ": " inside the spoken text intact
        # (a plain split()[1] would truncate it).
        speaker, text = line.split(": ", 1)
        speech.append((speaker, text))

In [25]:
# Splitting by the .srt timings keeps the .wav segments aligned with the
# .txt/.srt transcript. Splitting by the JSON output instead would keep them
# aligned with the embeddings.

# Timing lines and speaker lines are paired positionally. zip() would silently
# drop extras and shift every later pairing if any cue lacked a
# "Speaker: text" line (or had several), so fail loudly instead.
assert len(timestamps) == len(speech), (
    f"{len(timestamps)} timing lines vs {len(speech)} speaker lines"
)

transcripts = [
    (timeToSeconds(start), timeToSeconds(end), speaker)
    for (start, end), (speaker, _text) in zip(timestamps, speech)
]

aggregated = aggregate_timestamps(transcripts)

split_waveform_by_timestamps(sample_wav, "segments", aggregated)