# Alignment Decoding task from UCU Acoustic School

## Given code

From [here](https://gist.github.com/proger/a7e820fbfa0181273fdbf2351901d0d8), with minor improvements.

In [81]:
import torch
import torch.nn as nn
import torchaudio
from g2p_en import G2p
from tqdm import tqdm

from praatio import textgrid as tgio
from praatio.data_classes.interval_tier import Interval

In [82]:
def make_frames(wav):
    return torchaudio.compliance.kaldi.mfcc(wav)

In [83]:
class LibriSpeech(torch.utils.data.Dataset):
    def __init__(self, url="dev-clean"):
        super().__init__()
        self.librispeech = torchaudio.datasets.LIBRISPEECH(".", url=url, download=True)

    def __len__(self):
        return len(self.librispeech)

    def __getitem__(self, index):
        wav, sr, text, speaker_id, chapter_id, utterance_id = self.librispeech[index]
        return make_frames(wav), text, (speaker_id, chapter_id, utterance_id) # returning additional ids

In [84]:
class Encoder(nn.Module):
    def __init__(self, input_dim=13, subsample_dim=128, hidden_dim=1024):
        super().__init__()
        self.subsample = nn.Conv1d(input_dim, subsample_dim, 5, stride=4, padding=3)
        self.lstm = nn.LSTM(
            subsample_dim, hidden_dim, batch_first=True, num_layers=3, dropout=0.2
        )

    def subsampled_lengths(self, input_lengths):
        # https://github.com/vdumoulin/conv_arithmetic
        p, k, s = (
            self.subsample.padding[0],
            self.subsample.kernel_size[0],
            self.subsample.stride[0],
        )
        o = input_lengths + 2 * p - k
        o = torch.floor(o / s + 1)
        return o.int()

    def forward(self, inputs):
        x = inputs
        x = self.subsample(x.mT).mT
        x = x.relu()
        x, _ = self.lstm(x)
        return x.relu()

In [85]:
class Vocabulary:
    def __init__(self):
        self.g2p = G2p()

        # http://www.speech.cs.cmu.edu/cgi-bin/cmudict
        self.rdictionary = ["ε", # CTC blank
                            " ",
                            "AA0", "AA1", "AE0", "AE1", "AH0", "AH1", "AO0", "AO1", "AW0", "AW1", "AY0", "AY1",
                            "B", "CH", "D", "DH",
                            "EH0", "EH1", "ER0", "ER1", "EY0", "EY1",
                            "F", "G", "HH",
                            "IH0", "IH1", "IY0", "IY1",
                            "JH", "K", "L", "M", "N", "NG",
                            "OW0", "OW1", "OY0", "OY1",
                            "P", "R", "S", "SH", "T", "TH",
                            "UH0", "UH1", "UW0", "UW1",
                            "V", "W", "Y", "Z", "ZH"]

        self.dictionary = {c: i for i, c in enumerate(self.rdictionary)}

    def __len__(self):
        return len(self.rdictionary)

    def encode(self, text):
        labels = [c.replace('2', '0') for c in self.g2p(text) if c != "'"]
        targets = torch.LongTensor([self.dictionary[phoneme] for phoneme in labels])
        return targets

    def decode(self, targets): # function to get symbol from token id
        return [self.rdictionary[token] for token in targets]

In [86]:
class Recognizer(nn.Module):
    def __init__(self, feat_dim=1024, vocab_size=55+1):
        super().__init__()
        self.classifier = nn.Linear(feat_dim, vocab_size)

    def forward(self, features):
        features = self.classifier(features)
        return features.log_softmax(dim=-1)

In [87]:
vocab = Vocabulary()
encoder = Encoder()
recognizer = Recognizer()

In [88]:
ckpt = torch.load('lstm_p3_360+500.pt', map_location='cpu')
encoder.load_state_dict(ckpt['encoder'])
recognizer.load_state_dict(ckpt['recognizer'])

<All keys matched successfully>

In [89]:
audio_frames, text, ids = LibriSpeech()[0]
phonemes = vocab.encode(text)

In [90]:
features = encoder(audio_frames)
outputs = recognizer.forward(features) # (T, 55+1)

## My code

Sample text:

In [91]:
text

'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'

In [92]:
print('real phonemes:       ', ''.join(vocab.decode(phonemes)))
print(f'predicted phonemes: ', ''.join(vocab.decode(torch.argmax(outputs, dim=1))))

real phonemes:        MIH1STER0 KWIH1LTER0 IH1Z DHAH0 AH0PAA1SAH0L AH1V DHAH0 MIH1DAH0L KLAE1SAH0Z AH0ND WIY1 AA1R GLAE1D TUW1 WEH1LKAH0M HHIH1Z GAA1SPAH0L
predicted phonemes:  εεεεεεεεεεεεεεεεεεεεMIH1STER0  KRIH1LTεεER0ε  εIH1Z DHAH0  AH0εPPεAA1SεAH0AH0LL  AH1V DHAH0 MIH1DAH0LL KLεAE1εSεAH0AH0εZεε  εAH0ND WIH1εRR GLLεAE1DDε TTWWWEH1LKKεAH0M   HHIH1Zε  GεAA1SSPPεAH0Lεεεεεεεε


Sample output size vs sample audio size

In [93]:
outputs.size()[0], audio_frames.size()[0]

(147, 584)

Code to segment all audio frames and save results

In [94]:
output_dir = 'result'

In [102]:
utt_i=0
for audio_frames, text, ids in tqdm(LibriSpeech()):
    # try:
    features = encoder(audio_frames)
    outputs = recognizer.forward(features)

    tg = tgio.Textgrid()
    tg.minTimestamp = 0
    tg.maxTimestamp = audio_frames.size()[0] / 100

    tier_name = 'phones'
    phones_tier = tgio.IntervalTier(tier_name, [], minT=0, maxT=tg.maxTimestamp)

    intervals = []
    output_tokens = torch.argmax(outputs, dim=1)
    decoded_output_tokens = vocab.decode(output_tokens)
    prev_token, prev_start = None, 0

    for i, token in enumerate(decoded_output_tokens):
        # print(f'{i}: {token=}, {prev_token=}, {prev_start=};')
        if prev_token != token and prev_token:
            intervals.append(Interval(prev_start, i/25, prev_token))
            prev_token = token
            prev_start = i/25
        elif not prev_token:
            prev_token = token
    if prev_token:
        intervals.append(Interval(prev_start, tg.maxTimestamp, prev_token))

    new_phonemes_tier = phones_tier.new(entries=intervals)
    tg.addTier(new_phonemes_tier)

    tg.save(f'{output_dir}/{ids[0]}_{ids[1]}_{ids[2]}.TextGrid',
            includeBlankSpaces=True,
            format='long_textgrid',
            reportingMode='error')
    # except Exception as e:
    #     print(e)
    #     print(utt_i, ids)
    utt_i += 1

  0%|          | 4/2703 [00:04<54:46,  1.22s/it]


KeyboardInterrupt: 