# Import

In [17]:
import os
import re
from dataclasses import dataclass

import librosa
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Config

In [18]:
# Base directory of the dataset (relative path); the CSV and the wav paths are resolved against it.
dataset_path_base = '../dataset/fluent_speech_commands_dataset/'
# Model identifier passed to `from_pretrained` for both the processor and the CTC model.
model_name = 'vasista22/ccc-wav2vec2-base-100h'

# Forced Alignment

In [19]:
def get_trellis(emission, tokens, blank_id=0):
    """Build the forced-alignment trellis for a token sequence.

    ``trellis[t, j]`` holds the best cumulative score of being at token ``j``
    at frame ``t``, where at every frame the path either stays on the current
    token (scored with the ``blank_id`` column) or advances to the next token
    (scored with that token's column).

    Args:
        emission: 2-D tensor indexed as ``[frame, label]``.
            NOTE(review): the recurrence adds these values, so they are
            presumably log-probabilities -- confirm against the caller.
        tokens: Sequence of label indices, one per transcript character.
        blank_id: Column of ``emission`` used as the "stay" score.

    Returns:
        Tensor of shape ``(num_frame, num_tokens)``.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)
    # Loop-invariant: the emission columns of every token that can be entered.
    # An explicit index tensor also keeps the single-token case well defined.
    next_tokens = torch.as_tensor(tokens[1:], dtype=torch.long)

    trellis = torch.zeros((num_frame, num_tokens))
    # Staying on the first token accumulates the blank score frame after frame.
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    # No token other than the first can be occupied at frame 0.
    trellis[0, 1:] = -float("inf")
    # The last (num_tokens - 1) frames are marked +inf in the first column, as
    # in the original. With a single token there are no such frames; the
    # unguarded slice `[-0:]` would overwrite the whole column.
    if num_tokens > 1:
        trellis[-(num_tokens - 1):, 0] = float("inf")

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, next_tokens],
        )
    return trellis

@dataclass
class Point:
    """One step of an alignment path: which token is occupied at which frame."""
    # Index into the token sequence (trellis column).
    token_index: int
    # Frame index (trellis row).
    time_index: int
    # `exp` of the frame-wise emission value associated with this step.
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    """Trace the best alignment path backwards through ``trellis``.

    Starts at the last frame on the last token and walks one frame back at a
    time, deciding at each step whether the path stayed on the current token
    or had just advanced from the previous one.

    Args:
        trellis: Score matrix indexed ``[frame, token]`` (see ``get_trellis``).
        emission: Frame-wise scores indexed ``[frame, label]``.
        tokens: Label index of every transcript token.
        blank_id: Column of ``emission`` used as the "stay" score.

    Returns:
        List of ``Point`` in increasing frame order, one per frame.

    Raises:
        AssertionError: if frame 0 is reached before the first token.
    """
    frame = trellis.size(0) - 1
    tok = trellis.size(1) - 1

    # Collected from the last frame towards the first; reversed before returning.
    reversed_path = [Point(tok, frame, emission[frame, blank_id].exp().item())]

    while tok > 0:
        # Should not happen but just in case
        assert frame > 0

        prev = frame - 1
        # Frame-wise scores of the two possible moves into the current cell.
        stay_score = emission[prev, blank_id]
        change_score = emission[prev, tokens[tok]]

        # The same two moves with the accumulated trellis context added.
        stay_total = trellis[prev, tok] + stay_score
        change_total = trellis[prev, tok - 1] + change_score

        frame = prev
        if change_total > stay_total:
            tok -= 1
            step_score = change_score
        else:
            step_score = stay_score

        # Record the frame-wise probability of the move that was taken.
        reversed_path.append(Point(tok, frame, step_score.exp().item()))

    # The first token has been reached; pad the remaining leading frames so
    # that the path covers every frame.
    for earlier in range(frame - 1, -1, -1):
        reversed_path.append(Point(tok, earlier, emission[earlier, blank_id].exp().item()))

    reversed_path.reverse()
    return reversed_path

# Merge the labels
@dataclass
class Segment:
    """A labelled span of frames with an associated score."""
    # Text of the span: a single transcript character or a merged word.
    label: str
    # First frame of the span (inclusive).
    start: int
    # End frame of the span (exclusive, as the `[start, end)` repr shows).
    end: int
    # Score attached to the span by whichever function built it.
    score: float

    def __repr__(self):
        # Compact one-line form: label, score to two decimals, half-open frame range.
        return f"{self.label} ({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        """Number of frames covered by the segment (``end - start``)."""
        return self.end - self.start


def merge_repeats(path, transcript):
    """Collapse consecutive path points on the same token into one segment.

    Args:
        path: Alignment path in frame order; each item exposes ``token_index``,
            ``time_index`` and ``score``.
        transcript: Indexable transcript; ``transcript[token_index]`` is the
            label of a token.

    Returns:
        List of ``Segment``, one per run of identical ``token_index``. Each
        segment spans the run's frames as ``[start, end)`` and carries the
        mean score of the run.
    """
    segments = []
    total = len(path)
    run_start = 0
    while run_start < total:
        token = path[run_start].token_index

        # Extend the run while the following points stay on the same token.
        run_end = run_start + 1
        while run_end < total and path[run_end].token_index == token:
            run_end += 1

        run = path[run_start:run_end]
        mean_score = sum(point.score for point in run) / len(run)
        segments.append(
            Segment(
                label=transcript[token],
                start=run[0].time_index,
                end=run[-1].time_index + 1,
                score=mean_score,
            )
        )
        run_start = run_end
    return segments

# Merge words
def merge_words(segments, separator="|"):
    """Group character segments into word segments.

    Segments whose label equals ``separator`` delimit words and are dropped;
    empty groups (leading, trailing or repeated separators) produce nothing.

    Args:
        segments: Character-level ``Segment`` objects in time order.
        separator: Label that marks a word boundary.

    Returns:
        List of ``Segment``, one per word: the label is the concatenated
        character labels, the span runs from the first character's start to
        the last character's end, and the score is the length-weighted mean
        of the character scores.
    """
    words = []
    current = []

    def flush():
        # Emit the characters collected so far as one word, if there are any.
        if not current:
            return
        text = "".join([seg.label for seg in current])
        weighted = sum(seg.score * seg.length for seg in current)
        frames = sum(seg.length for seg in current)
        words.append(Segment(text, current[0].start, current[-1].end, weighted / frames))
        current.clear()

    for seg in segments:
        if seg.label == separator:
            flush()
        else:
            current.append(seg)
    # The last word is not followed by a separator.
    flush()
    return words


# Model Selection

In [20]:
# Load the pretrained processor and the CTC model identified by `model_name`.
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

# Data

In [28]:
# Pass the path components separately: concatenating '/data' onto a base path that
# already ends in '/' defeated the purpose of os.path.join and doubled the separator.
train_df = pd.read_csv(os.path.join(dataset_path_base, 'data', 'train_data.csv'))
train_df.head(10)

Unnamed: 0.1,Unnamed: 0,path,speakerId,transcription,action,object,location
0,0,wavs/speakers/2BqVo8kVB2Skwgyb/0a3129c0-4474-1...,2BqVo8kVB2Skwgyb,Change language,change language,none,none
1,1,wavs/speakers/2BqVo8kVB2Skwgyb/0ee42a80-4474-1...,2BqVo8kVB2Skwgyb,Resume,activate,music,none
2,2,wavs/speakers/2BqVo8kVB2Skwgyb/144d5be0-4474-1...,2BqVo8kVB2Skwgyb,Turn the lights on,activate,lights,none
3,3,wavs/speakers/2BqVo8kVB2Skwgyb/1811b6e0-4474-1...,2BqVo8kVB2Skwgyb,Switch on the lights,activate,lights,none
4,4,wavs/speakers/2BqVo8kVB2Skwgyb/1d9f3920-4474-1...,2BqVo8kVB2Skwgyb,Switch off the lights,deactivate,lights,none
5,5,wavs/speakers/2BqVo8kVB2Skwgyb/269fc210-4474-1...,2BqVo8kVB2Skwgyb,Volume up,increase,volume,none
6,6,wavs/speakers/2BqVo8kVB2Skwgyb/5bbda3f0-4478-1...,2BqVo8kVB2Skwgyb,Turn the volume up,increase,volume,none
7,7,wavs/speakers/2BqVo8kVB2Skwgyb/6436ad60-4478-1...,2BqVo8kVB2Skwgyb,Turn the volume down,decrease,volume,none
8,8,wavs/speakers/2BqVo8kVB2Skwgyb/6a1cd6f0-4478-1...,2BqVo8kVB2Skwgyb,Turn up the temperature,increase,heat,none
9,9,wavs/speakers/2BqVo8kVB2Skwgyb/72160200-4478-1...,2BqVo8kVB2Skwgyb,Turn the heat up,increase,heat,none


In [22]:
def hook(module, input, output):
    """Forward-hook callback: store the hooked module's output in the global ``layer_output``.

    Only ``output`` is used; ``module`` and ``input`` are accepted because the
    callback is registered with ``register_forward_hook``. Each call
    overwrites the previously stored value.
    """
    global layer_output
    layer_output = output

In [23]:
import torchaudio
# Character labels of torchaudio's WAV2VEC2_ASR_BASE_960H pipeline. `wordify` uses the
# position of each character in `labels` as its token id.
# NOTE(review): these labels come from a different checkpoint than `model_name`;
# confirm that the ids are meant to index whatever is passed to `get_trellis`.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
labels = bundle.get_labels()

In [24]:
# Capture the output of the final layer norm of encoder layer index 11 on every forward pass.
# NOTE(review): re-running this cell registers one more hook each time; call
# `hook_handle.remove()` before re-executing it.
layer = model.wav2vec2.encoder.layers[11].final_layer_norm
hook_handle = layer.register_forward_hook(hook)

In [32]:

def wordify(idx):
    """Return one pooled embedding per word for a training utterance.

    The audio of row ``idx`` of ``train_df`` is run through ``model``; the
    activation captured by the forward hook (``layer_output``) is segmented
    into words with the forced-alignment helpers, and each word's frames are
    L2-normalised and averaged.

    Args:
        idx: Row label in ``train_df`` (its ``path`` and ``transcription``
            columns are read).

    Returns:
        ``(word_embeds, transcript)``: ``word_embeds`` is a list with one
        tensor per aligned word (mean over the word's frames, frame dimension
        kept), and ``transcript`` is the upper-cased transcription with
        punctuation removed and whitespace replaced by ``|``.

    Raises:
        KeyError: if the cleaned transcript contains a character that is not
            in ``labels``.
    """
    filePath = os.path.join(dataset_path_base, train_df.loc[idx, 'path'])
    audio, rate = librosa.load(filePath, sr=16000)
    transcript = train_df.loc[idx, 'transcription']

    # Single utterance, so the batch dimension is 1.
    input_values = processor(audio, sampling_rate=rate, return_tensors="pt", padding="longest").input_values

    # The forward pass is run only for its side effect: the registered hook
    # stores the layer activation in the global `layer_output`.
    with torch.no_grad():
        model(input_values)

    # Drop the batch dimension.
    out = layer_output[0]

    # Raw strings keep the patterns byte-identical while avoiding the
    # invalid-escape-sequence warnings the non-raw literals trigger.
    chars_to_ignore_regex = r'[\,\?\.\-\;\:\’]'
    transcript = re.sub(chars_to_ignore_regex, '', transcript)
    transcript = re.sub(r"\s", "|", transcript).upper()

    dictionary = {c: i for i, c in enumerate(labels)}
    tokens = [dictionary[c] for c in transcript]

    # NOTE(review): `out` is the hooked layer activation, yet it is handed to
    # get_trellis/backtrack as the emission matrix -- they index it by label id
    # and exponentiate it as if it held per-label log-probabilities. Confirm
    # this is intended; aligning on log-softmaxed CTC logits and pooling `out`
    # with the resulting word boundaries would be the conventional setup.
    trellis = get_trellis(out, tokens)
    transition_path = backtrack(trellis, out, tokens)
    segments = merge_repeats(transition_path, transcript)
    word_segments = merge_words(segments)

    word_embeds = []
    for word in word_segments:
        # Normalise every frame vector, then average over the word's frames.
        word_embeds.append(torch.mean(F.normalize(out[word.start:word.end]), 0, True))

    return word_embeds, transcript

In [33]:
# Word embeddings and cleaned transcripts for training rows 3 and 4.
tokenized_op1, transcript_1 = wordify(3)
tokenized_op2, transcript_2 = wordify(4)

In [35]:
import torch.nn as nn
cos = nn.CosineSimilarity(dim=1, eps=1e-6)

# Top row: words of the first utterance.
print('\t'.join(transcript_1.split('|')))
# Compare the utterances word by word. Iterating over the embeddings themselves
# (instead of a fixed range(4)) handles any word count; zip stops at the
# shorter utterance.
for embed_1, embed_2 in zip(tokenized_op1, tokenized_op2):
    output = cos(embed_1, embed_2)
    print(f'{output.item():0.4f}', end='\t')
print()
# Bottom row: words of the second utterance.
print('\t'.join(transcript_2.split('|')))

SWITCH	ON	THE	LIGHTS
0.9673	0.6662	0.9267	0.9521	
SWITCH	OFF	THE	LIGHTS


In [36]:
# Word embeddings and cleaned transcripts for training rows 8 and 9.
tokenized_op1, transcript_1 = wordify(8)
tokenized_op2, transcript_2 = wordify(9)

In [37]:
import torch.nn as nn
cos = nn.CosineSimilarity(dim=1, eps=1e-6)

# Top row: words of the first utterance.
print('\t'.join(transcript_1.split('|')))
# Compare the utterances word by word. Iterating over the embeddings themselves
# (instead of a fixed range(4)) handles any word count; zip stops at the
# shorter utterance.
for embed_1, embed_2 in zip(tokenized_op1, tokenized_op2):
    output = cos(embed_1, embed_2)
    print(f'{output.item():0.4f}', end='\t')
print()
# Bottom row: words of the second utterance.
print('\t'.join(transcript_2.split('|')))

TURN	UP	THE	TEMPERATURE
0.7121	0.6039	0.7888	0.5248	
TURN	THE	HEAT	UP


In [38]:
# Word embeddings and cleaned transcripts for training rows 6 and 7.
tokenized_op1, transcript_1 = wordify(6)
tokenized_op2, transcript_2 = wordify(7)

In [39]:
import torch.nn as nn
cos = nn.CosineSimilarity(dim=1, eps=1e-6)

# Top row: words of the first utterance.
print('\t'.join(transcript_1.split('|')))
# Compare the utterances word by word. Iterating over the embeddings themselves
# (instead of a fixed range(4)) handles any word count; zip stops at the
# shorter utterance.
for embed_1, embed_2 in zip(tokenized_op1, tokenized_op2):
    output = cos(embed_1, embed_2)
    print(f'{output.item():0.4f}', end='\t')
print()
# Bottom row: words of the second utterance.
print('\t'.join(transcript_2.split('|')))

TURN	THE	VOLUME	UP
0.7726	0.9083	0.9074	0.8004	
TURN	THE	VOLUME	DOWN
