# wav2vec_test
- force align each annotation/wav snippet 
- save results in json file

In [1]:
import os
import json
import torch
import IPython
import torchaudio
import matplotlib.pyplot as plt
from dataclasses import dataclass


print(torch.__version__)
print(torchaudio.__version__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

torch.random.manual_seed(0)

2.1.2+cpu
2.1.2+cpu
cpu


<torch._C.Generator at 0x279f9ed9110>

In [68]:
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()

print(labels)

('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')


In [69]:
folder_path = 'C:\\Users\\barth\\gits\\pytorch_wav2vec\\test_data\\'
inputPath = folder_path + "input\\"
mediaPath = folder_path + "media_snippets\\"

def read_json_file(filename):
    try:
        with open(filename, 'r') as json_file:
            data = json.load(json_file)
            return data
    except Exception as e:
        print(f"Error reading JSON file: {e}")
        return None

def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    trellis[0, 1:] = -float("inf")
    trellis[-num_tokens + 1 :, 0] = float("inf")

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens[1:]],
        )
    return trellis

def plot():
    fig, ax = plt.subplots()
    img = ax.imshow(trellis.T, origin="lower")
    ax.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
    ax.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
    fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
    fig.tight_layout()

@dataclass
class Point:
    token_index: int
    time_index: int
    score: float

def backtrack(trellis, emission, tokens, blank_id=0):
    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # Should not happen but just in case
        assert t > 0

        # 1. Figure out if the current position was stay or change
        # Frame-wise score of stay vs change
        p_stay = emission[t - 1, blank_id]
        p_change = emission[t - 1, tokens[j]]

        # Context-aware score for stay vs change
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change

        # Update position
        t -= 1
        if changed > stayed:
            j -= 1

        # Store the path with frame-wise probability.
        prob = (p_change if changed > stayed else p_stay).exp().item()
        path.append(Point(j, t, prob))

    # Now j == 0, which means, it reached the SoS.
    # Fill up the rest for the sake of visualization
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1

    return path[::-1]

# Merge the labels
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start


def merge_repeats(path):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments

# Merge words
def merge_words(segments, separator="|"):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words

def segment_info(i):
    ratio = waveform.size(1) / trellis.size(0)
    word = word_segments[i]
    x0 = int(ratio * word.start)
    x1 = int(ratio * word.end)
    #print(f"{word.label} ({word.score:.2f}): {x0 / 44100:.3f} - {x1 / 44100:.3f} sec")
    return ([f"{word.label}", format(x0 / 44100,'.3f') , format(x1 / 44100, '.3f')])
    

for filename in os.listdir(mediaPath):
    counter  = 0
    
    if filename.endswith('.json'):
        jsonFile = mediaPath + filename
        json_data = read_json_file(jsonFile)
        for k, v in json_data.items():
            
            text_clean = v["text"].upper().replace(' ', '|')
            #print (text_clean)
            transcript = text_clean
            dictionary = {c: i for i, c in enumerate(labels)}
            
            tokens = [dictionary[c] for c in transcript]
            #print(list(zip(transcript, tokens)))

            SPEECH_FILE = "C:/Users/barth/gits/pytorch_wav2vec/test_data/media_snippets/" + filename[:-4] + "\\" + k.replace(" ", "_") + ".wav" #Cathy_Samun_Wiliang_a1.wav"
            
            #bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
            
            #model = bundle.get_model().to(device)
            #labels = bundle.get_labels()
            
            
            with torch.inference_mode():
                waveform, _ = torchaudio.load(SPEECH_FILE)
                emissions, _ = model(waveform.to(device))
                emissions = torch.log_softmax(emissions, dim=-1)
            
            emission = emissions[0].cpu().detach()
            metadata = torchaudio.info(SPEECH_FILE)
            print(metadata)
            
            trellis = get_trellis(emission, tokens)
           
            #plot()
            
            path = backtrack(trellis, emission, tokens)

            segments = merge_repeats(path)
            
                
            word_segments = merge_words(segments)
            word_list = []
            for x in range(0, len(word_segments)):
                wordInfo = segment_info(x)
                wordInfo[1] = int(float(wordInfo[1]) * 1000) + v['timeStamp1']
                wordInfo[2] = int(float(wordInfo[2]) * 1000) + v['timeStamp1']
                print(wordInfo)
                
                word_list.append(wordInfo)


            print (word_list)
            print(v)
            v["tier_new"] = v["tiername"] + "_words"
            v["words"] = word_list
                
            
            
            #print(transcript)
            #display_segment(-1)
            #IPython.display.Audio(SPEECH_FILE, rate=44100)

        print (json_data)
        with open(jsonFile, 'w') as json_file:
            json.dump(json_data, json_file, indent=2)

print (" +++ DONE +++" )

AudioMetaData(sample_rate=44100, num_frames=121275, num_channels=2, bits_per_sample=16, encoding=PCM_S)
['YANGAU', 4730, 5407]
['CATHY', 6294, 6687]
['SAMUN', 6745, 7116]
['AI', 7276, 7480]
[['YANGAU', 4730, 5407], ['CATHY', 6294, 6687], ['SAMUN', 6745, 7116], ['AI', 7276, 7480]]
{'tiername': 'Cathy Samun Wiliang', 'annoID': 'a1', 'timeSlotRef1': 'ts1', 'timeSlotRef2': 'ts2', 'timeStamp1': 4730, 'timeStamp2': 7480, 'text': 'Yangau Cathy Samun ai'}
AudioMetaData(sample_rate=44100, num_frames=84672, num_channels=2, bits_per_sample=16, encoding=PCM_S)
['NGAHAU', 8860, 9216]
['MAM', 9296, 9427]
['YANGAN', 9507, 9653]
['DAIDAI', 9682, 10067]
['AI', 10075, 10780]
[['NGAHAU', 8860, 9216], ['MAM', 9296, 9427], ['YANGAN', 9507, 9653], ['DAIDAI', 9682, 10067], ['AI', 10075, 10780]]
{'tiername': 'Cathy Samun Wiliang', 'annoID': 'a2', 'timeSlotRef1': 'ts3', 'timeSlotRef2': 'ts4', 'timeStamp1': 8860, 'timeStamp2': 10780, 'text': 'Ngahau mam yangan Daidai ai'}
AudioMetaData(sample_rate=44100, num_fr

In [18]:
def display_segment(i):
    ratio = waveform.size(1) / trellis.size(0)
    word = word_segments[i]
    x0 = int(ratio * word.start)
    x1 = int(ratio * word.end)
    print(f"{word.label} ({word.score:.2f}): {x0 / 44100:.3f} - {x1 / 44100:.3f} sec")
    segment = waveform[:, x0:x1]

    return IPython.display.Audio(segment.numpy(), rate=44100)

display_segment(1)

TAMAT (0.78): 0.618 - 0.851 sec


In [8]:
IPython.display.Audio(SPEECH_FILE, rate=44100)

In [None]:
"C:\Users\barth\gits\pytorch_wav2vec\test_data\media_snippets\Cathy_Samun_Wiliang_a1.wav"
"C:\Users\barth\gits\pytorch_wav2vec\test_data\media_snippets\Cathy Samun Wiliang_a1.wav"

"C:\Users\barth\gits\pytorch_wav2vec\test_data\media_snippets\Cathy_Samun_Wiliang_a53.wav"