In [72]:
import torch
import torch.nn as nn
import torchaudio
from g2p_en import G2p
from tqdm import tqdm

from praatio import textgrid as tgio
from praatio.data_classes.interval_tier import Interval

In [3]:
def make_frames(wav):
    return torchaudio.compliance.kaldi.mfcc(wav)

In [62]:
class LibriSpeech(torch.utils.data.Dataset):
    def __init__(self, url="dev-clean"):
        super().__init__()
        self.librispeech = torchaudio.datasets.LIBRISPEECH(".", url=url, download=True)

    def __len__(self):
        return len(self.librispeech)

    def __getitem__(self, index):
        wav, sr, text, speaker_id, chapter_id, utterance_id = self.librispeech[index]
        return make_frames(wav), text, (speaker_id, chapter_id, utterance_id)

In [5]:
class Encoder(nn.Module):
    def __init__(self, input_dim=13, subsample_dim=128, hidden_dim=1024):
        super().__init__()
        self.subsample = nn.Conv1d(input_dim, subsample_dim, 5, stride=4, padding=3)
        self.lstm = nn.LSTM(
            subsample_dim, hidden_dim, batch_first=True, num_layers=3, dropout=0.2
        )

    def subsampled_lengths(self, input_lengths):
        # https://github.com/vdumoulin/conv_arithmetic
        p, k, s = (
            self.subsample.padding[0],
            self.subsample.kernel_size[0],
            self.subsample.stride[0],
        )
        o = input_lengths + 2 * p - k
        o = torch.floor(o / s + 1)
        return o.int()

    def forward(self, inputs):
        x = inputs
        x = self.subsample(x.mT).mT
        x = x.relu()
        x, _ = self.lstm(x)
        return x.relu()

In [57]:
class Vocabulary:
    def __init__(self):
        self.g2p = G2p()

        # http://www.speech.cs.cmu.edu/cgi-bin/cmudict
        self.rdictionary = ["ε", # CTC blank
                            " ",
                            "AA0", "AA1", "AE0", "AE1", "AH0", "AH1", "AO0", "AO1", "AW0", "AW1", "AY0", "AY1",
                            "B", "CH", "D", "DH",
                            "EH0", "EH1", "ER0", "ER1", "EY0", "EY1",
                            "F", "G", "HH",
                            "IH0", "IH1", "IY0", "IY1",
                            "JH", "K", "L", "M", "N", "NG",
                            "OW0", "OW1", "OY0", "OY1",
                            "P", "R", "S", "SH", "T", "TH",
                            "UH0", "UH1", "UW0", "UW1",
                            "V", "W", "Y", "Z", "ZH"]

        self.dictionary = {c: i for i, c in enumerate(self.rdictionary)}

    def __len__(self):
        return len(self.rdictionary)

    def encode(self, text):
        labels = [c.replace('2', '0') for c in self.g2p(text) if c != "'"]
        targets = torch.LongTensor([self.dictionary[phoneme] for phoneme in labels])
        return targets

    def decode(self, targets):
        return [self.rdictionary[token] for token in targets]

In [58]:
class Recognizer(nn.Module):
    def __init__(self, feat_dim=1024, vocab_size=55+1):
        super().__init__()
        self.classifier = nn.Linear(feat_dim, vocab_size)

    def forward(self, features):
        features = self.classifier(features)
        return features.log_softmax(dim=-1)

In [59]:
vocab = Vocabulary()
encoder = Encoder()
recognizer = Recognizer()

In [60]:
ckpt = torch.load('lstm_p3_360+500.pt', map_location='cpu')
encoder.load_state_dict(ckpt['encoder'])
recognizer.load_state_dict(ckpt['recognizer'])

<All keys matched successfully>

In [63]:
audio_frames, text, ids = LibriSpeech()[0]
phonemes = vocab.encode(text)

In [64]:
ids

(1272, 128104, 0)

In [31]:
features = encoder(audio_frames)
outputs = recognizer.forward(features) # (T, 55+1)

In [32]:
text

'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'

In [39]:
print('real phonemes:       ', ''.join(vocab.decode(phonemes)))
print(f'predicted phonemes: ', ''.join(vocab.decode(torch.argmax(outputs, dim=1))))

real phonemes:        MIH1STER0 KWIH1LTER0 IH1Z DHAH0 AH0PAA1SAH0L AH1V DHAH0 MIH1DAH0L KLAE1SAH0Z AH0ND WIY1 AA1R GLAE1D TUW1 WEH1LKAH0M HHIH1Z GAA1SPAH0L
predicted phonemes:  εεεεεεεεεεεεεεεεεεεεMIH1STER0  KRIH1LTTεER0ε  IH1Z DHAH0  εAH0εPPεAA1SεAH0AH0LL  AH1V DHAH0 MIH1DAH0LL KLLAE1εSεAH0εεZεε  εAH0ND WIH1εRR GLLAE1εDDε TεUW1WWEH1LεKεAH0M  HHHHIH1Z   GεAA1εSPPεAH0Lεεεεεεεε


In [55]:
outputs.size()[0]

147

In [41]:
output_dir = 'result'

In [76]:
utt_i=0
for audio_frames, text, ids in tqdm(LibriSpeech()):
    try:
        features = encoder(audio_frames)
        outputs = recognizer.forward(features)

        tg = tgio.Textgrid()
        tg.minTimestamp = 0
        tg.maxTimestamp = outputs.size()[0] / 25

        tier_name = 'phones'
        phones_tier = tgio.IntervalTier(tier_name, [], minT=0, maxT=tg.maxTimestamp)

        intervals = []
        output_tokens = torch.argmax(outputs, dim=1)
        decoded_output_tokens = vocab.decode(output_tokens)
        prev_token, prev_start = None, 0

        for i, token in enumerate(decoded_output_tokens):
            if prev_token != token and prev_token:
                intervals.append(Interval(prevStart, i/25, prev_token))
                prev_token = token
                prevStart = i/25
            elif not prev_token:
                prev_token = token
        if prev_token:
            intervals.append(Interval(prevStart, tg.maxTimestamp, prev_token))

        new_phonemes_tier = phones_tier.new(entries=intervals)
        tg.addTier(new_phonemes_tier)

        tg.save(f'{output_dir}/{ids[0]}_{ids[1]}_{ids[2]}.TextGrid',
                includeBlankSpaces=True,
                format='long_textgrid',
                reportingMode='error')
    except Exception as e:
        print(e)
        print(utt_i, ids)
    utt_i += 1

 15%|█▍        | 401/2703 [04:03<15:04,  2.55it/s]

The start time of an interval (2.76) cannot occur after its end time (2.75)
400 (1988, 24833, 22)


 26%|██▌       | 700/2703 [07:22<10:52,  3.07it/s]

The start time of an interval (2.88) cannot occur after its end time (2.88)
699 (2277, 149896, 28)


 27%|██▋       | 742/2703 [07:41<21:35,  1.51it/s]

The start time of an interval (3.44) cannot occur after its end time (3.43)
741 (2277, 149897, 35)


 28%|██▊       | 757/2703 [07:51<21:26,  1.51it/s]

The start time of an interval (1.96) cannot occur after its end time (1.96)
756 (2412, 153947, 12)


 28%|██▊       | 767/2703 [08:00<29:47,  1.08it/s]

The start time of an interval (3.12) cannot occur after its end time (3.12)
766 (2412, 153948, 5)


 30%|██▉       | 805/2703 [08:24<14:12,  2.23it/s]

The start time of an interval (4.24) cannot occur after its end time (4.24)
804 (2428, 83699, 2)


 30%|███       | 816/2703 [08:30<13:10,  2.39it/s]

The start time of an interval (3.28) cannot occur after its end time (3.27)
815 (2428, 83699, 13)


 32%|███▏      | 866/2703 [08:52<13:55,  2.20it/s]

The start time of an interval (5.52) cannot occur after its end time (5.51)
865 (2428, 83705, 20)


 32%|███▏      | 872/2703 [08:55<17:13,  1.77it/s]

The start time of an interval (10.2) cannot occur after its end time (10.19)
871 (2428, 83705, 26)


 37%|███▋      | 990/2703 [09:59<16:27,  1.73it/s]

The start time of an interval (6.32) cannot occur after its end time (6.32)
989 (2803, 154328, 10)


 42%|████▏     | 1140/2703 [12:00<12:36,  2.07it/s]

The start time of an interval (2.72) cannot occur after its end time (2.71)
1139 (3081, 166546, 33)


 44%|████▍     | 1183/2703 [12:17<09:09,  2.77it/s]

The start time of an interval (1.92) cannot occur after its end time (1.92)
1182 (3081, 166546, 76)


 47%|████▋     | 1275/2703 [13:27<20:01,  1.19it/s]

The start time of an interval (7.44) cannot occur after its end time (7.44)
1274 (3536, 23268, 29)


 48%|████▊     | 1286/2703 [13:38<14:27,  1.63it/s]

The start time of an interval (2.0) cannot occur after its end time (1.99)
1285 (3536, 8226, 9)


 48%|████▊     | 1287/2703 [13:38<15:44,  1.50it/s]

The start time of an interval (9.48) cannot occur after its end time (9.48)
1286 (3536, 8226, 10)


 49%|████▉     | 1335/2703 [14:19<28:51,  1.27s/it]

The start time of an interval (7.52) cannot occur after its end time (7.51)
1334 (3576, 138058, 25)


 50%|████▉     | 1341/2703 [14:24<16:44,  1.36it/s]

The start time of an interval (2.8) cannot occur after its end time (2.79)
1340 (3576, 138058, 31)


 51%|█████     | 1375/2703 [14:46<08:33,  2.59it/s]

The start time of an interval (5.64) cannot occur after its end time (5.64)
1374 (3752, 4943, 24)


 51%|█████▏    | 1390/2703 [14:50<05:44,  3.81it/s]

The start time of an interval (1.92) cannot occur after its end time (1.92)
1389 (3752, 4944, 8)


 52%|█████▏    | 1397/2703 [14:53<08:34,  2.54it/s]

The start time of an interval (4.12) cannot occur after its end time (4.12)
1396 (3752, 4944, 15)


 52%|█████▏    | 1419/2703 [15:05<08:25,  2.54it/s]

The start time of an interval (2.56) cannot occur after its end time (2.56)
1418 (3752, 4944, 37)


 53%|█████▎    | 1433/2703 [15:10<07:14,  2.92it/s]

The start time of an interval (3.52) cannot occur after its end time (3.52)
1432 (3752, 4944, 51)


 53%|█████▎    | 1439/2703 [15:13<08:30,  2.48it/s]

The start time of an interval (2.52) cannot occur after its end time (2.52)
1438 (3752, 4944, 57)


 54%|█████▎    | 1449/2703 [15:17<08:09,  2.56it/s]

The start time of an interval (2.68) cannot occur after its end time (2.68)
1448 (3752, 4944, 67)


 54%|█████▍    | 1459/2703 [15:25<14:10,  1.46it/s]

The start time of an interval (5.48) cannot occur after its end time (5.48)
1458 (3853, 163249, 7)


 54%|█████▍    | 1467/2703 [15:30<13:26,  1.53it/s]

The start time of an interval (2.28) cannot occur after its end time (2.28)
1466 (3853, 163249, 15)


 55%|█████▍    | 1484/2703 [15:45<21:22,  1.05s/it]

The start time of an interval (5.32) cannot occur after its end time (5.32)
1483 (3853, 163249, 32)


 59%|█████▉    | 1595/2703 [18:52<13:27,  1.37it/s]  

The start time of an interval (1.92) cannot occur after its end time (1.92)
1594 (5338, 284437, 25)


 59%|█████▉    | 1599/2703 [18:56<15:28,  1.19it/s]

The start time of an interval (7.56) cannot occur after its end time (7.56)
1598 (5338, 284437, 29)


 60%|██████    | 1630/2703 [19:28<16:00,  1.12it/s]

The start time of an interval (4.32) cannot occur after its end time (4.32)
1629 (5536, 43359, 6)


 61%|██████    | 1643/2703 [19:41<15:15,  1.16it/s]

The start time of an interval (2.8) cannot occur after its end time (2.8)
1642 (5536, 43363, 0)


 62%|██████▏   | 1667/2703 [20:11<16:06,  1.07it/s]

The start time of an interval (3.68) cannot occur after its end time (3.68)
1666 (5694, 64025, 4)


 63%|██████▎   | 1695/2703 [20:44<15:40,  1.07it/s]

The start time of an interval (3.0) cannot occur after its end time (3.0)
1694 (5694, 64029, 8)


 63%|██████▎   | 1705/2703 [20:53<14:07,  1.18it/s]

The start time of an interval (2.48) cannot occur after its end time (2.47)
1704 (5694, 64029, 18)


 63%|██████▎   | 1706/2703 [20:54<13:14,  1.25it/s]

The start time of an interval (3.4) cannot occur after its end time (3.4)
1705 (5694, 64029, 19)


 69%|██████▉   | 1861/2703 [23:00<13:15,  1.06it/s]

The start time of an interval (6.0) cannot occur after its end time (6.0)
1860 (6241, 61946, 7)


 69%|██████▉   | 1875/2703 [23:15<18:49,  1.36s/it]

The start time of an interval (4.44) cannot occur after its end time (4.44)
1874 (6241, 61946, 21)


 71%|███████   | 1918/2703 [24:10<14:08,  1.08s/it]

The start time of an interval (2.0) cannot occur after its end time (2.0)
1917 (6295, 244435, 14)


 71%|███████▏  | 1927/2703 [24:20<12:29,  1.04it/s]

The start time of an interval (3.4) cannot occur after its end time (3.39)
1926 (6295, 244435, 23)


 73%|███████▎  | 1986/2703 [25:32<08:21,  1.43it/s]

The start time of an interval (4.48) cannot occur after its end time (4.47)
1985 (6313, 66125, 8)


 74%|███████▎  | 1991/2703 [25:36<07:26,  1.59it/s]

The start time of an interval (2.04) cannot occur after its end time (2.04)
1990 (6313, 66125, 13)


 75%|███████▍  | 2015/2703 [25:56<07:49,  1.47it/s]

The start time of an interval (4.76) cannot occur after its end time (4.76)
2014 (6313, 66129, 9)


 75%|███████▍  | 2017/2703 [25:57<07:11,  1.59it/s]

The start time of an interval (2.2) cannot occur after its end time (2.2)
2016 (6313, 66129, 11)


 76%|███████▌  | 2049/2703 [26:24<06:42,  1.62it/s]

The start time of an interval (3.68) cannot occur after its end time (3.68)
2048 (6313, 76958, 7)


 76%|███████▌  | 2061/2703 [26:35<07:49,  1.37it/s]

The start time of an interval (3.08) cannot occur after its end time (3.08)
2060 (6313, 76958, 19)


 80%|███████▉  | 2161/2703 [28:31<05:50,  1.54it/s]

The start time of an interval (2.12) cannot occur after its end time (2.11)
2160 (6345, 93302, 11)


 80%|████████  | 2164/2703 [28:33<06:21,  1.41it/s]

The start time of an interval (6.08) cannot occur after its end time (6.08)
2163 (6345, 93302, 14)


 80%|████████  | 2172/2703 [28:39<04:42,  1.88it/s]

The start time of an interval (2.56) cannot occur after its end time (2.55)
2171 (6345, 93302, 22)


 80%|████████  | 2173/2703 [28:39<05:27,  1.62it/s]

The start time of an interval (5.52) cannot occur after its end time (5.51)
2172 (6345, 93302, 23)


 80%|████████  | 2174/2703 [28:40<05:16,  1.67it/s]

The start time of an interval (2.52) cannot occur after its end time (2.52)
2173 (6345, 93302, 24)


 81%|████████  | 2179/2703 [28:46<12:13,  1.40s/it]

The start time of an interval (13.12) cannot occur after its end time (13.12)
2178 (6345, 93302, 29)


 81%|████████  | 2187/2703 [29:00<11:21,  1.32s/it]

The start time of an interval (5.12) cannot occur after its end time (5.11)
2186 (6345, 93306, 7)


 81%|████████  | 2192/2703 [29:05<09:04,  1.07s/it]

The start time of an interval (2.56) cannot occur after its end time (2.55)
2191 (6345, 93306, 12)


 83%|████████▎ | 2256/2703 [29:45<03:09,  2.36it/s]

The start time of an interval (7.76) cannot occur after its end time (7.76)
2255 (652, 130726, 29)


 83%|████████▎ | 2257/2703 [29:45<02:43,  2.73it/s]

The start time of an interval (2.84) cannot occur after its end time (2.83)
2256 (652, 130726, 30)


 84%|████████▎ | 2262/2703 [29:48<03:51,  1.90it/s]

The start time of an interval (5.08) cannot occur after its end time (5.07)
2261 (652, 130726, 35)


 91%|█████████ | 2448/2703 [31:10<01:51,  2.28it/s]

The start time of an interval (1.92) cannot occur after its end time (1.92)
2447 (7976, 105575, 17)


 93%|█████████▎| 2503/2703 [31:34<01:42,  1.95it/s]

The start time of an interval (3.12) cannot occur after its end time (3.12)
2502 (7976, 110523, 16)


 96%|█████████▌| 2592/2703 [32:15<00:50,  2.20it/s]

The start time of an interval (7.0) cannot occur after its end time (7.0)
2591 (84, 121123, 8)


 96%|█████████▌| 2595/2703 [32:16<00:40,  2.64it/s]

The start time of an interval (3.24) cannot occur after its end time (3.23)
2594 (84, 121123, 11)


 97%|█████████▋| 2614/2703 [32:23<00:43,  2.04it/s]

The start time of an interval (7.96) cannot occur after its end time (7.96)
2613 (84, 121550, 1)


 97%|█████████▋| 2617/2703 [32:25<00:46,  1.83it/s]

The start time of an interval (7.96) cannot occur after its end time (7.95)
2616 (84, 121550, 4)


 97%|█████████▋| 2635/2703 [32:36<00:39,  1.74it/s]

The start time of an interval (7.64) cannot occur after its end time (7.63)
2634 (84, 121550, 22)


 98%|█████████▊| 2655/2703 [32:48<00:32,  1.46it/s]

The start time of an interval (5.36) cannot occur after its end time (5.35)
2654 (8842, 302196, 6)


 99%|█████████▉| 2685/2703 [33:04<00:08,  2.11it/s]

The start time of an interval (3.68) cannot occur after its end time (3.68)
2684 (8842, 302203, 7)


100%|██████████| 2703/2703 [33:16<00:00,  1.35it/s]
