In [2]:
from dataclasses import dataclass
import torch
import torchaudio
import matplotlib.pyplot as plt
import pandas as pd
import os
from tqdm import tqdm
import re
from torch.utils.data import DataLoader

# Seed PyTorch's default random generator with 0 so repeated runs draw the same numbers.
torch.manual_seed(0)

<torch._C.Generator at 0x7fdac849beb0>

In [3]:
# Root directory of the dataset; CSV files and audio paths are resolved against it.
dataset_path_base = '../dataset/fluent_speech_commands_dataset/'
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
def hook(module, input, output):
    """Forward-hook callback that stashes a layer's output in a module-level global.

    The signature is the one PyTorch forward hooks are called with; ``module``
    and ``input`` are ignored. After each forward pass of the hooked layer the
    most recent ``output`` is available as the global ``layer_output``.
    """
    # Equivalent to declaring ``global layer_output`` and assigning to it.
    globals()["layer_output"] = output

In [5]:
def get_trellis(emission, tokens, blank_id=0):
    """Build the alignment trellis for ``tokens`` over the frames of ``emission``.

    Args:
        emission: 2-D tensor indexed as ``emission[frame, label_id]``
            (presumably frame-wise log-probabilities -- the scores are summed).
        tokens: sequence of label ids, one per transcript symbol.
        blank_id: column of ``emission`` used as the "stay" score.

    Returns:
        Tensor of shape ``(num_frames, len(tokens))`` where entry ``[t, j]`` is
        the best accumulated score of being at token ``j`` at frame ``t``.
        The tensor lives on the same device as ``emission``.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Allocate on the emission's device so GPU inputs do not mix devices below.
    trellis = torch.zeros((num_frame, num_tokens), device=emission.device)
    # Column 0: accumulated blank score from frame 1 onwards (row 0 stays 0).
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    # At frame 0 only the first token can be occupied.
    trellis[0, 1:] = -float("inf")
    # Mark the first column of the last ``num_tokens - 1`` frames with +inf.
    # Guarded: with a single token the slice start would be ``-0 == 0`` and the
    # whole column (i.e. the entire trellis) would be overwritten with inf.
    if num_tokens > 1:
        trellis[-(num_tokens - 1):, 0] = float("inf")

    # Loop-invariant index of the tokens that can be transitioned *into*.
    next_tokens = torch.as_tensor(tokens[1:], dtype=torch.long, device=emission.device)

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, next_tokens],
        )
    return trellis

In [6]:
@dataclass
class Point:
    """One step of the alignment path built by ``backtrack``."""
    # Trellis column (position within ``tokens``) occupied at this step.
    token_index: int
    # Trellis row (frame index into ``emission``) of this step.
    time_index: int
    # ``exp()`` of the emission entry selected for this step.
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    """Trace the best path through ``trellis`` from its last cell back to frame 0.

    Args:
        trellis: 2-D score tensor indexed ``[frame, token_position]``.
        emission: 2-D tensor indexed ``[frame, label_id]``; its entries are
            exponentiated to produce the per-step scores.
        tokens: sequence of label ids, one per trellis column.
        blank_id: emission column used for the "stay" decision.

    Returns:
        List of ``Point`` ordered by increasing ``time_index``, one per frame.

    Raises:
        AssertionError: if frame 0 is reached before token position 0.
    """
    frame = trellis.size(0) - 1
    position = trellis.size(1) - 1

    # The path is collected backwards in time and flipped at the end.
    steps = [Point(position, frame, emission[frame, blank_id].exp().item())]

    while position > 0:
        # Should not happen but just in case
        assert frame > 0

        previous = frame - 1
        # Frame-wise scores of the two possible moves into the current cell.
        stay_score = emission[previous, blank_id]
        change_score = emission[previous, tokens[position]]

        # Context-aware comparison: did we arrive by changing token or by staying?
        arrived_by_change = bool(
            trellis[previous, position - 1] + change_score
            > trellis[previous, position] + stay_score
        )

        frame = previous
        if arrived_by_change:
            position -= 1
            chosen = change_score
        else:
            chosen = stay_score

        # Store the step with its frame-wise probability.
        steps.append(Point(position, frame, chosen.exp().item()))

    # position == 0 here: fill the remaining frames for the sake of visualization.
    for earlier in range(frame - 1, -1, -1):
        steps.append(Point(position, earlier, emission[earlier, blank_id].exp().item()))

    steps.reverse()
    return steps


In [7]:
# Merge the labels
@dataclass
class Segment:
    """A labelled span ``[start, end)`` together with a score."""
    # Text carried by the span.
    label: str
    # First index covered by the span.
    start: int
    # Index one past the last covered position (printed as a half-open range).
    end: int
    # Score attached to the span.
    score: float

    def __repr__(self):
        """Return ``"<label> (<score>): [<start>, <end>)"`` with fixed-width numeric fields."""
        return f"{self.label} ({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        """Number of positions covered by the span, i.e. ``end - start``."""
        return self.end - self.start


def merge_repeats(path, transcript):
    """Collapse consecutive path points that share a token into one ``Segment``.

    Args:
        path: list of points exposing ``token_index``, ``time_index`` and ``score``.
        transcript: indexable of labels; ``transcript[token_index]`` names a segment.

    Returns:
        List of ``Segment`` in path order. Each spans from the first point's
        ``time_index`` to one past the last point's ``time_index`` of its run,
        and carries the mean score of the run.
    """
    segments = []
    total = len(path)
    run_start = 0
    while run_start < total:
        token = path[run_start].token_index

        # Advance to the first point that belongs to a different token.
        run_end = run_start + 1
        while run_end < total and path[run_end].token_index == token:
            run_end += 1

        run = path[run_start:run_end]
        mean_score = sum(point.score for point in run) / len(run)
        segments.append(
            Segment(
                transcript[token],
                run[0].time_index,
                run[-1].time_index + 1,
                mean_score,
            )
        )
        run_start = run_end
    return segments

In [8]:
# Merge words
def merge_words(segments, separator="|"):
    """Join runs of segments between ``separator`` segments into word segments.

    Args:
        segments: sequence of segments exposing ``label``, ``start``, ``end``,
            ``score`` and ``length``.
        separator: label that delimits words; separator segments are dropped.

    Returns:
        List of ``Segment``, one per non-empty run. The label is the
        concatenated labels, the span runs from the first segment's ``start``
        to the last segment's ``end``, and the score is the length-weighted
        mean of the member scores.
    """

    def build_word(parts):
        # Length-weighted average so longer segments count proportionally more.
        text = "".join([part.label for part in parts])
        weighted = sum(part.score * part.length for part in parts) / sum(part.length for part in parts)
        return Segment(text, parts[0].start, parts[-1].end, weighted)

    words = []
    pending = []  # segments of the word currently being assembled
    for segment in segments:
        if segment.label != separator:
            pending.append(segment)
            continue
        # Separator reached: close the current word, if any.
        if pending:
            words.append(build_word(pending))
            pending = []

    # Input may end without a trailing separator.
    if pending:
        words.append(build_word(pending))
    return words


In [9]:
def getAudioAndTranscripts(file, label2id):
    """Read one split CSV and return audio paths, transcripts and per-word labels.

    Args:
        file: CSV file name, resolved under ``dataset_path_base + '/data/'``.
            The CSV must provide ``path``, ``transcription``, ``object`` and
            ``location`` columns.
        label2id: mapping with the keys ``'O'``, ``'object'`` and ``'location'``.

    Returns:
        Tuple ``(speech_files, transcripts, targets)`` of equal-length lists.
        ``targets[i]`` holds one label id per space-separated word of
        ``transcripts[i]``: the ``'object'`` id when the word equals the row's
        ``object`` value, the ``'location'`` id when it equals the row's
        ``location`` value, and the ``'O'`` id otherwise.
    """
    frame = pd.read_csv(os.path.join(dataset_path_base+'/data/',file))

    speech_files, transcripts, targets = [], [], []
    for row in tqdm(range(len(frame))):
        speech_files.append(os.path.join(dataset_path_base, frame.loc[row, 'path']))
        transcript = frame.loc[row, 'transcription']
        transcripts.append(transcript)

        object_word = frame.loc[row, 'object']
        location_word = frame.loc[row, 'location']
        # Exact string match per word; the object test takes precedence.
        targets.append([
            label2id['object'] if word == object_word
            else label2id['location'] if word == location_word
            else label2id['O']
            for word in transcript.split(" ")
        ])
    return speech_files, transcripts, targets

In [10]:
# Tag set for word-level labelling; a tag's position in the list is its integer id.
label_names = ['O','object', 'location']

# Forward (name -> id) and reverse (id -> name) lookups over the tag set.
label2id = {name: idx for idx, name in enumerate(label_names)}
id2label = dict(enumerate(label_names))

# Load audio paths, transcripts and per-word label ids for each split.
audio_paths_train, transcripts_train, targets_train = getAudioAndTranscripts(
    'train_data.csv', label2id
)
audio_paths_valid, transcripts_valid, targets_valid = getAudioAndTranscripts(
    'valid_data.csv', label2id
)
audio_paths_test, transcripts_test, targets_test = getAudioAndTranscripts(
    'test_data.csv', label2id
)

 18%|█▊        | 4227/23132 [00:00<00:00, 42264.90it/s]

100%|██████████| 23132/23132 [00:00<00:00, 42740.61it/s]
100%|██████████| 3118/3118 [00:00<00:00, 42131.40it/s]
100%|██████████| 3793/3793 [00:00<00:00, 43572.56it/s]


In [11]:
# class dataset(Dataset):
#     def __init__(self, audio_paths, transcripts, targets, max_len):
#         self.len = len(audio_paths)
#         self.audio_paths = audio_paths
#         self.transcripts = transcripts
#         self.targets = targets
#         self.max_len = max_len
        
#     def __getitem__(self, index):
#         # step 1: tokenize (and adapt corresponding labels)
#         sentence = self.data.transcription[index]  
#         word_labels = self.data.labels[index]  
#         tokenized_sentence, labels = tokenize_and_preserve_labels(sentence, word_labels, self.tokenizer)
        
#         # step 2: add special tokens (and corresponding labels)
#         tokenized_sentence = ["[CLS]"] + tokenized_sentence + ["[SEP]"] # add special tokens
#         print(tokenized_sentence)
#         labels.insert(0, "O") # add outside label for [CLS] token
#         labels.insert(-1, "O") # add outside label for [SEP] token

#         # step 3: truncating/padding
#         maxlen = self.max_len

#         if (len(tokenized_sentence) > maxlen):
#           # truncate
#           tokenized_sentence = tokenized_sentence[:maxlen]
#           labels = labels[:maxlen]
#         else:
#           # pad
#           tokenized_sentence = tokenized_sentence + ['[PAD]' for _ in range(maxlen - len(tokenized_sentence))]
#           labels = labels + ["O" for _ in range(maxlen - len(labels))]

#         # step 4: obtain the attention mask
#         attn_mask = [1 if tok != '[PAD]' else 0 for tok in tokenized_sentence]
        
#         # step 5: convert tokens to input ids
#         ids = self.tokenizer.convert_tokens_to_ids(tokenized_sentence)
#         print(ids)

#         label_ids = [label2id[label] for label in labels]
        
#         return {
#               'ids': torch.tensor(ids, dtype=torch.long),
#               'mask': torch.tensor(attn_mask, dtype=torch.long),
#               #'token_type_ids': torch.tensor(token_ids, dtype=torch.long),
#               'targets': torch.tensor(label_ids, dtype=torch.long)
#         } 
    
    
#     def __len__(self):
#         return self.len


In [12]:
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()
maxlen = 128

# Loop-invariant lookups, built once instead of once per utterance.
dictionary = {c: i for i, c in enumerate(labels)}
# Raw strings: the patterns are unchanged, but no longer rely on invalid
# string escapes (which newer Python versions warn about).
chars_to_ignore_regex = r'[\,\?\.\-\;\:\’]'
whitespace_regex = r"\s"

# 11 is the last layer
layer = model.encoder.transformer.layers[11].final_layer_norm
hook_handle = layer.register_forward_hook(hook)
# NOTE(review): torch.load unpickles the file -- only load special_token.pt from a
# trusted source. Tensors are mapped to the CPU because the word embeddings they
# are concatenated with below live on the CPU.
special_tokens = torch.load('./special_token.pt', map_location='cpu')

try:
    for path, transcript, target in zip(audio_paths_train, transcripts_train, targets_train):
        waveform, sample_rate = torchaudio.load(path)
        # The model expects audio at the bundle's sample rate.
        if sample_rate != bundle.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            emissions, _ = model(waveform.to(device))
        print(emissions.shape)

        # Frame-wise log-probabilities over `labels` for the first batch item.
        # The trellis indexes its score matrix by label id, so alignment must run
        # on these scores, not on the hooked hidden states.
        emission = torch.log_softmax(emissions, dim=-1)[0].cpu()
        # Hidden states captured by the forward hook; used only for the embeddings.
        out = layer_output[0].detach().cpu()

        transcript = re.sub(chars_to_ignore_regex, '', transcript)
        transcript = re.sub(whitespace_regex, "|", transcript).upper()
        tokens = [dictionary[c] for c in transcript]

        trellis = get_trellis(emission, tokens)
        transition_path = backtrack(trellis, emission, tokens)
        segments = merge_repeats(transition_path, transcript)
        word_segments = merge_words(segments)

        # One mean-pooled vector per aligned word, concatenated to (num_words, dim).
        # (Stacking and then taking [0] kept only the first word.)
        word_embeds = torch.cat(
            [torch.mean(out[word.start:word.end], 0, True) for word in word_segments]
        )[:maxlen]

        # Pad the sentence up to `maxlen` rows with the [PAD] embedding.
        num_pad = maxlen - len(word_embeds)
        if num_pad > 0:
            pad = torch.stack([special_tokens['[PAD]'] for _ in range(num_pad)])
            tokenized_sentence = torch.cat((word_embeds, pad))
        else:
            tokenized_sentence = word_embeds

        # Truncate / pad the label ids to `maxlen`, padding with the 'O' label.
        label_ids = torch.tensor(target[:maxlen], dtype=torch.long)
        label_pad = torch.full((maxlen - len(label_ids),), label2id['O'], dtype=torch.long)
        label_ids = torch.cat((label_ids, label_pad))

        # Prototype run: only the first utterance is processed.
        break
finally:
    # Always detach the hook, even if an utterance fails, so it does not leak
    # into later forward passes of the model.
    hook_handle.remove()

torch.Size([1, 92, 29])
