In [13]:
import os, re, json
import copy
from IPython.display import Audio, display
import pandas as pd
import torch
import whisper
from typing import List

In [14]:
from utils import VAD, ASR

In [15]:
# Sample rate (Hz) used for all audio in this notebook.
SR = 16000
# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [16]:
# Grab a short file for testing if it doesn't exist.
example_file = './data/en_example.wav'
if not os.path.exists(example_file):
    # Make sure the target directory exists before downloading into it.
    os.makedirs(os.path.dirname(example_file), exist_ok=True)
    torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', example_file)

In [17]:
Audio(data=example_file, rate=SR)

In [18]:
# import utils.prepare as prep
# _ = prep.get_vad_data(example_file, refresh=True)
# _ = prep.get_asr_data(example_file, refresh=True, model_name='medium.en')

### Caching Data

In [19]:
def cache_json(data: object, path: str):
    """Serialize ``data`` as JSON into the file at ``path``, overwriting it if present."""
    with open(path, mode='w') as fh:
        json.dump(data, fh)

In [20]:
def load_json(path):
    """Read and return the JSON document stored at ``path``."""
    with open(path, mode='r') as fh:
        return json.load(fh)

## From Audio to VAD segments with caching

In [21]:
def get_waveform(filepath: str):
    """Decode ``filepath`` with ``whisper.load_audio`` and return the samples as a torch tensor."""
    return torch.from_numpy(whisper.load_audio(filepath))

In [22]:
def new_vad_data(audio):
    """Compute speech timestamps for ``audio`` using a freshly constructed ``VAD``."""
    return VAD().get_speech_timestamps(audio)

In [23]:
def get_vad_data(path:str, waveform=None, refresh=False):
    """Return VAD speech segments for the audio file at ``path``, cached as JSON.

    Args:
        path: audio file path; the cache is stored next to it as ``{path}_vad.json``.
        waveform: optional pre-loaded waveform for ``path``; loaded with
            ``get_waveform`` only when the segments actually have to be computed.
        refresh: recompute and overwrite the cache even if it already exists.

    Returns:
        The list of segment dicts from ``new_vad_data`` (each with ``'start'``/``'end'``
        sample indices, as consumed by ``collect_chunks``).
    """
    json_path = f"{path}_vad.json"
    if os.path.exists(json_path) and not refresh:
        return load_json(json_path)

    # `is None`, not `== None`: `==` on an array-like waveform may compare elementwise.
    if waveform is None:
        waveform = get_waveform(path)
    print('Calculating VAD segments.')
    segments = new_vad_data(waveform)
    cache_json(segments, json_path)
    print(f"Done. Data saved: '{json_path}'")
    return segments

In [24]:
segs = get_vad_data(path=example_file)

### Using VAD segments to extract speech only

In [25]:
def collect_chunks(wav: torch.Tensor, segments: List[dict]):
    """Return a waveform made of only the given ``segments`` of ``wav``, concatenated in order.

    Args:
        wav: the source waveform, sliced along its first dimension.
        segments: dicts with ``'start'``/``'end'`` sample indices (end-exclusive).

    Returns:
        The concatenated speech-only tensor. An empty ``segments`` list yields an
        empty tensor of the same dtype instead of letting ``torch.cat([])`` raise.
    """
    if not segments:
        return wav[:0]
    return torch.cat([wav[seg['start']:seg['end']] for seg in segments])


In [26]:
def align_chunks(segments: List[dict]):
    """Place each segment on the timeline of the waveform produced by ``collect_chunks``.

    Segments are laid end to end starting at 0, each keeping its original length,
    so entry ``k`` of the result tells where segment ``k`` lands in the concatenated audio.
    """
    aligned = []
    offset = 0
    for seg in segments:
        length = seg['end'] - seg['start']
        aligned.append({'start': offset, 'end': offset + length})
        offset += length
    return aligned

In [27]:
def drop_chunks(tss: List[dict],
                wav: torch.Tensor):
    """Return ``wav`` with every segment in ``tss`` cut out (the inverse of ``collect_chunks``).

    Args:
        tss: dicts with ``'start'``/``'end'`` sample indices; assumed sorted and
            non-overlapping (not checked).
        wav: the source waveform.

    Returns:
        The remaining audio concatenated in order, including everything after the
        last segment. An empty ``tss`` returns the whole waveform.
    """
    chunks = []
    cur_start = 0
    for ts in tss:
        chunks.append(wav[cur_start:ts['start']])
        cur_start = ts['end']
    # Keep the tail after the final segment too; it is not part of any dropped segment.
    chunks.append(wav[cur_start:])
    return torch.cat(chunks)

In [28]:
def collect_and_align_chunks(wav: torch.Tensor, segments: List[dict]):
    """Return ``(speech_only_wav, segments_on_that_wav)`` for the given VAD ``segments``."""
    return collect_chunks(wav, segments), align_chunks(segments)

In [29]:
wav2, segs2 = collect_and_align_chunks(wav=get_waveform(example_file), segments=get_vad_data(example_file))

## From Audio to ASR data with caching

In [30]:
def new_asr_data(audio, verbose=False):
    """Transcribe ``audio`` with a new ``ASR`` instance and return only its ``'segments'``."""
    transcription = ASR().transcribe(audio, verbose=verbose)
    return transcription['segments']


In [31]:
def get_asr_data(path, waveform=None, refresh=False):
    """Return ASR segments for the audio file at ``path``, cached in ``{path}_asr.json``.

    Transcription runs on the speech-only waveform (the VAD segments joined by
    ``collect_chunks``), so the returned timestamps refer to that shortened audio,
    not to the original file.

    Args:
        path: audio file path, also used to derive the cache file name.
        waveform: optional pre-loaded waveform for ``path``; loaded with
            ``get_waveform`` only when transcription is needed.
        refresh: re-transcribe and overwrite the cache even if it exists.

    Returns:
        The list of segments produced by ``new_asr_data``.
    """
    json_path = f"{path}_asr.json"
    if os.path.exists(json_path) and not refresh:
        return load_json(json_path)

    # `is None`, not `== None`: `==` on an array-like waveform may compare elementwise.
    if waveform is None:
        waveform = get_waveform(path)
    wav = collect_chunks(waveform, get_vad_data(path, waveform=waveform))
    print('Transcribing asr data')
    asr_data = new_asr_data(wav)
    cache_json(asr_data, json_path)
    print(f"Done. Data saved: '{json_path}'")
    return asr_data

In [32]:
asr_results = get_asr_data(path=example_file)

In [58]:
def get_full_text(data, separator = ''):
    """Join the ``'text'`` field of every segment in ``data`` using ``separator``."""
    return separator.join(seg['text'] for seg in data)

get_full_text(data=asr_results, separator='')

" and says, how do I get to Dublin? And the answer that comes back is, well, I wouldn't start from here, Sonny. That is to say, much of political philosophy develops theories that take no account of where we actually are and how the theories that people argue about in the journals and in the literature actually could be implemented in the world, if at all. And this spills over into normative arguments made by other scholars. Thomas Piketty in his book, Capital in the 21st Century, argues for a 4% global wealth tax. Well, good luck with that. Who's gonna implement a 4% global wealth tax? So when I think about normative questions."

## Aligning transcription with VAD data

In [34]:
import utils.alignment as align

In [35]:
wav = get_waveform(filepath=example_file)

In [36]:
# Use the configured DEVICE instead of hard-coding 'cuda', so this cell also runs on CPU-only machines.
align_model, metadata = align.load_align_model('en', device=DEVICE)

In [37]:
# asr_results was transcribed from the speech-only waveform (see get_asr_data), so align
# against that same waveform (wav2) rather than the full recording, on the configured DEVICE.
ouput = align.align(asr_results, align_model, metadata, wav2, DEVICE)

In [38]:
def convert_timestamp_segments_to_frames(segments, sr=SR):
    """Convert segment ``'start'``/``'end'`` values from seconds to integer sample indices.

    Args:
        segments: list of dicts with numeric ``'start'``/``'end'`` times in seconds;
            any other keys (e.g. ``'text'``) are passed through unchanged.
        sr: sample rate in Hz used for the conversion.

    Returns:
        A new list of dicts with ``'start'``/``'end'`` multiplied by ``sr`` and rounded
        to ints; the input dicts are not modified. An empty input returns ``[]``.
    """
    if len(segments) == 0:
        # pd.DataFrame([]) has no 'start'/'end' columns, so the selection below would raise KeyError.
        return []
    df = pd.DataFrame(segments)
    df[['start', 'end']] = (df[['start', 'end']] * sr).round().astype(int)
    return df.to_dict(orient='records')


def convert_frames_segments_to_timestamp(segments, sr=SR):
    """Convert segment ``'start'``/``'end'`` values from sample indices to seconds.

    Args:
        segments: list of dicts with ``'start'``/``'end'`` sample indices; any other
            keys are passed through unchanged.
        sr: sample rate in Hz used for the conversion.

    Returns:
        A new list of dicts with ``'start'``/``'end'`` divided by ``sr``; the input
        dicts are not modified. An empty input returns ``[]``.
    """
    if len(segments) == 0:
        # pd.DataFrame([]) has no 'start'/'end' columns, so the selection below would raise KeyError.
        return []
    df = pd.DataFrame(segments)
    df[['start', 'end']] = df[['start', 'end']] / sr
    return df.to_dict(orient='records')

In [39]:
word_segs = convert_timestamp_segments_to_frames(segments=ouput['word_segments'])

In [65]:
class Aligner:
    """Load a forced-alignment model once (via ``align.load_align_model``) and reuse it across ``align`` calls."""

    def __init__(self, language='en', device = DEVICE):
        self.model, self.metadata = align.load_align_model(language_code=language,device=device)
        self.device = device

    def align(self, transcript, audio, interoplate_method = "linear"):
        """Align ``transcript`` segments against ``audio`` and return ``align.align``'s result.

        Args:
            transcript: ASR segments, e.g. from ``get_asr_data``.
            audio: waveform whose timeline ``transcript`` refers to.
            interoplate_method: forwarded as ``interpolate_method``; the misspelled
                parameter name is kept so existing keyword callers keep working.
        """
        return align.align(
            transcript=transcript,
            model=self.model,
            # Use this instance's metadata, not the notebook-global `metadata`,
            # which may belong to a different model or language.
            align_model_metadata=self.metadata,
            audio=audio,
            device=self.device,
            interpolate_method=interoplate_method,
        )

In [41]:
aligner = Aligner()
output = aligner.align(transcript=asr_results, audio=wav2)

In [42]:
wordsegs = convert_timestamp_segments_to_frames(segments=output['word_segments'])
# segs = convert_timestamp_segments_to_frames(output['segments'])

Now that I have much better timestamps than Whisper provides, I need to map them back onto the timeline of the original audio. At the moment they refer to the shortened, speech-only waveform built from the VAD segments.  
The VAD segments are very accurate, and they already carry positions in the original audio because they were computed from it. So if I match each word to its parent VAD segment, I can shift the word back onto the original timeline.

In [43]:
segs2[0:3]

[{'start': 0, 'end': 29632},
 {'start': 29632, 'end': 60800},
 {'start': 60800, 'end': 89920}]

In [44]:
wordsegs[0:25]

[{'text': 'and', 'start': 0, 'end': 1293},
 {'text': 'says,', 'start': 2264, 'end': 6144},
 {'text': 'how', 'start': 7761, 'end': 10348},
 {'text': 'do', 'start': 10671, 'end': 12611},
 {'text': 'I', 'start': 13905, 'end': 14875},
 {'text': 'get', 'start': 15845, 'end': 18432},
 {'text': 'to', 'start': 19725, 'end': 21666},
 {'text': 'Dublin?', 'start': 22312, 'end': 28456},
 {'text': 'And', 'start': 31043, 'end': 32660},
 {'text': 'the', 'start': 33307, 'end': 34600},
 {'text': 'answer', 'start': 36541, 'end': 40421},
 {'text': 'that', 'start': 41068, 'end': 43331},
 {'text': 'comes', 'start': 44301, 'end': 48505},
 {'text': 'back', 'start': 50122, 'end': 54973},
 {'text': 'is,', 'start': 56913, 'end': 57883},
 {'text': 'well,', 'start': 61440, 'end': 64029},
 {'text': 'I', 'start': 64676, 'end': 65323},
 {'text': "wouldn't", 'start': 65970, 'end': 70501},
 {'text': 'start', 'start': 71471, 'end': 75678},
 {'text': 'from', 'start': 76649, 'end': 78914},
 {'text': 'here,', 'start': 802

In [45]:
from pprint import pprint

# Toy example: the first two VAD segments (parents) and the aligned words (children) around them.
parents = [
    {'start': 0, 'end': 29632},
    {'start': 29632, 'end': 60800},
]

children = [
    {'text': 'and', 'start': 0, 'end': 1293},
    {'text': 'says,', 'start': 2264, 'end': 6144},
    {'text': 'how', 'start': 7761, 'end': 10348},
    {'text': 'do', 'start': 10671, 'end': 12611},
    {'text': 'I', 'start': 13905, 'end': 14875},
    {'text': 'get', 'start': 15845, 'end': 18432},
    {'text': 'to', 'start': 19725, 'end': 21666},
    {'text': 'Dublin?', 'start': 22312, 'end': 28456},
    {'text': 'And', 'start': 31043, 'end': 32660},
    {'text': 'the', 'start': 33307, 'end': 34600},
    {'text': 'answer', 'start': 36541, 'end': 40421},
    {'text': 'that', 'start': 41068, 'end': 43331},
    {'text': 'comes', 'start': 44301, 'end': 48505},
    {'text': 'back', 'start': 50122, 'end': 54973},
]

def get_range_overlap_percent(parent: range, child: range) -> float:
    """Return the fraction of ``child`` that lies inside ``parent``, in ``[0.0, 1.0]``.

    Both arguments are expected to be step-1 ``range`` objects over sample indices
    (end-exclusive), as built by ``child_in_parent``.

    An empty ``child`` (a zero-length segment) counts as fully inside when its start
    point falls within ``parent`` and as outside otherwise. An empty ``parent`` never
    contains anything.
    """
    if len(child) == 0:
        return 1.0 if child.start in parent else 0.0
    # Overlap length of two half-open intervals; negative means disjoint, so clamp at 0.
    overlap = max(0, min(parent.stop, child.stop) - max(parent.start, child.start))
    return overlap / len(child)

def child_in_parent(parent, child, threshold=0.6) -> bool:
    """Return True when more than ``threshold`` of ``child``'s span lies inside ``parent``.

    Both arguments are dicts with integer ``'start'``/``'end'`` keys (end-exclusive).
    """
    parent_span = range(parent['start'], parent['end'])
    child_span = range(child['start'], child['end'])
    fraction_inside = get_range_overlap_percent(parent_span, child_span)
    return fraction_inside > threshold

def gather_children(parents, children, threshold=0.6):
    """Attach to each parent segment the child segments that mostly fall inside it.

    Args:
        parents: segment dicts with ``'start'``/``'end'`` keys.
        children: segment dicts with ``'start'``/``'end'`` keys, on the same timeline.
        threshold: passed to ``child_in_parent``; a child is attached when more than
            this fraction of its span overlaps the parent.

    Returns:
        A new list of shallow copies of ``parents``, each with an added
        ``'children'`` list in the original order of ``children``. The child dicts
        themselves are shared, not copied; no input dict is modified.
    """
    results = []
    for parent in parents:
        # Shallow copy so the caller's parent dict does not gain a 'children' key.
        grouped = parent.copy()
        grouped['children'] = [child for child in children if child_in_parent(parent, child, threshold)]
        results.append(grouped)
    return results

# Use a real name: assigning to `_` would shadow IPython's last-output variable.
gathered = gather_children(parents, children)

pprint(parents)
pprint(gathered)

[{'end': 29632, 'start': 0}, {'end': 60800, 'start': 29632}]
[{'children': [{'end': 1293, 'start': 0, 'text': 'and'},
               {'end': 6144, 'start': 2264, 'text': 'says,'},
               {'end': 10348, 'start': 7761, 'text': 'how'},
               {'end': 12611, 'start': 10671, 'text': 'do'},
               {'end': 14875, 'start': 13905, 'text': 'I'},
               {'end': 18432, 'start': 15845, 'text': 'get'},
               {'end': 21666, 'start': 19725, 'text': 'to'},
               {'end': 28456, 'start': 22312, 'text': 'Dublin?'}],
  'end': 29632,
  'start': 0},
 {'children': [{'end': 32660, 'start': 31043, 'text': 'And'},
               {'end': 34600, 'start': 33307, 'text': 'the'},
               {'end': 40421, 'start': 36541, 'text': 'answer'},
               {'end': 43331, 'start': 41068, 'text': 'that'},
               {'end': 48505, 'start': 44301, 'text': 'comes'},
               {'end': 54973, 'start': 50122, 'text': 'back'}],
  'end': 60800,
  'start': 29632}]


In [46]:
srt_data = gather_children(parents=segs2, children=wordsegs)

## ASR + Alignment caching

In [47]:
def new_alignment(asr_data, audio):
    """Align ``asr_data`` against ``audio`` and return the word segments with sample-index bounds."""
    word_segments = Aligner().align(asr_data, audio)['word_segments']
    return convert_timestamp_segments_to_frames(word_segments)

In [48]:
def get_alignment_data(path, waveform=None, vad_data=None, asr_data=None, refresh=False):
    """Return word-level alignment for ``path``, cached in ``{path}_align.json``.

    On a cache miss (or when ``refresh`` is true), any of ``waveform``, ``vad_data``
    and ``asr_data`` that were not supplied are obtained from ``get_waveform``,
    ``get_vad_data`` and ``get_asr_data`` respectively. Those helpers use their own
    caches, and ``refresh`` is not forwarded to them. The ASR segments are then aligned
    against the speech-only waveform built from the VAD segments. The resulting
    word segments (start/end as sample indices) are written to the cache.

    Returns:
        The list of aligned word-segment dicts from ``new_alignment``.
    """
    json_path = f"{path}_align.json"
    if os.path.exists(json_path) and not refresh:
        return load_json(json_path)

    # `is None`, not `== None`: `==` on an array-like argument may compare elementwise.
    if waveform is None:
        waveform = get_waveform(path)
    if vad_data is None:
        vad_data = get_vad_data(path, waveform)
    if asr_data is None:
        asr_data = get_asr_data(path, waveform)
    print('Aligning ASR data')
    # The ASR timestamps refer to the speech-only audio, so align against that same audio.
    short_wav = collect_chunks(waveform, vad_data)
    alignment_data = new_alignment(asr_data, short_wav)
    cache_json(alignment_data, json_path)
    print(f"Done. Data saved: '{json_path}'")
    return alignment_data

In [49]:
def listen_segment(segment, audio, sr=None):
    """Display an inline audio player for ``audio[segment['start']:segment['end']]``.

    Args:
        segment: dict with ``'start'``/``'end'`` sample indices into ``audio``.
        audio: waveform to slice, e.g. the tensor returned by ``get_waveform``.
        sr: playback sample rate in Hz; defaults to the notebook-wide ``SR`` instead
            of a separately hard-coded copy of it.
    """
    rate = SR if sr is None else sr
    clip = audio[segment['start']:segment['end']]
    display(Audio(clip, rate=rate))

In [50]:
def map_alignment_data_to_vad_data(alignment_data, vad_data):
    """Group the word segments in ``alignment_data`` under the VAD segment each one falls in.

    ``vad_data`` is deep-copied first, so nothing in the caller's segments is shared with the result.
    """
    return gather_children(copy.deepcopy(vad_data), alignment_data)

In [51]:
def migrate_children(source, target):
    """Return a deep copy of ``target`` carrying ``source``'s children, shifted onto ``target``'s timeline.

    Every child's ``'start'``/``'end'`` is offset by ``target['start'] - source['start']``.
    Neither input is modified.
    """
    offset = target['start'] - source['start']
    migrated = copy.deepcopy(target)
    migrated['children'] = [
        {**child, 'start': child['start'] + offset, 'end': child['end'] + offset}
        for child in copy.deepcopy(source['children'])
    ]
    return migrated


In [62]:
def batch_migrate_children(source, target):
    """Apply ``migrate_children`` pairwise over ``source`` and ``target`` (stops at the shorter list)."""
    migrated = []
    for src_seg, tgt_seg in zip(source, target):
        migrated.append(migrate_children(src_seg, tgt_seg))
    return migrated

In [59]:
get_full_text(data=srt_data[0]['children'], separator=' ')

'and says, how do I get to Dublin?'

In [68]:
vad_data = get_vad_data(example_file)
vad_short = align_chunks(segments=vad_data)
align_data = get_alignment_data(example_file)
srt_short = map_alignment_data_to_vad_data(alignment_data=align_data, vad_data=vad_short)
srt_data = batch_migrate_children(source=srt_short, target=vad_data)

# Spot-check the first three segments: print their words and play them from the original audio.
for segment in srt_data[0:3]:
    print(get_full_text(segment['children'], separator=' '))
    listen_segment(segment, wav)

and says, how do I get to Dublin?


And the answer that comes back is,


well, I wouldn't start from here, Sonny.


In [71]:
def get_srt_data(path, refresh=False):
    """Build subtitle-style segments for ``path``: one entry per VAD segment with its aligned words.

    Each returned dict is a VAD segment (``'start'``/``'end'`` in samples of the
    original audio) plus ``'children'`` (its word segments, shifted onto the original
    timeline) and ``'text'`` (those words joined with spaces).

    With ``refresh=True`` the VAD, ASR and alignment caches are all recomputed, so the
    transcript is regenerated from the same VAD segmentation the alignment uses.
    """
    vad_data = get_vad_data(path, refresh=refresh)
    # get_alignment_data does not forward `refresh` to the ASR step, so it would reuse
    # a transcript cached against a possibly different VAD segmentation.
    asr_data = get_asr_data(path, refresh=True) if refresh else None
    align_data = get_alignment_data(path, vad_data=vad_data, asr_data=asr_data, refresh=refresh)
    # Words were aligned on the speech-only audio. Group them by the VAD segments
    # laid end to end, then shift each group back onto the original timeline.
    vad_short = align_chunks(vad_data)
    srt_short = gather_children(vad_short, align_data)
    srt_data = batch_migrate_children(srt_short, vad_data)
    for segment in srt_data:
        segment['text'] = get_full_text(segment['children'], ' ')
    return srt_data


In [72]:
data = get_srt_data(path=example_file)
df = pd.DataFrame(data=data)
df.head(n=5)

Unnamed: 0,start,end,children,text
0,1568,31200,"[{'text': 'and', 'start': 1568, 'end': 2861}, ...","and says, how do I get to Dublin?"
1,42528,73696,"[{'text': 'And', 'start': 43939, 'end': 45556}...","And the answer that comes back is,"
2,79392,108512,"[{'text': 'well,', 'start': 80032, 'end': 8262...","well, I wouldn't start from here, Sonny."
3,149024,163808,"[{'text': 'That', 'start': 149666, 'end': 1525...","That is to say,"
4,166944,181728,"[{'text': 'much', 'start': 168569, 'end': 1730...",much of
