In [50]:
import torch
import torch.nn as nn
import torchaudio
import numpy as np
from g2p_en import G2p

Author: Maksym Sarana

In [21]:
! pip install g2p-en

Collecting g2p-en
  Downloading g2p_en-2.1.0-py3-none-any.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting distance>=0.1.3
  Downloading Distance-0.1.3.tar.gz (180 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m180.3/180.3 kB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting inflect>=0.3.1
  Downloading inflect-6.0.4-py3-none-any.whl (34 kB)
Collecting pydantic>=1.9.1
  Downloading pydantic-1.10.8-cp310-cp310-macosx_11_0_arm64.whl (2.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Building wheels for collected packages: distance
  Building wheel for distance (setup.py) ... [?25ldone
[?25h  Created wheel for distance: filename=Distance-0.1.3-py3-none-any.whl size=16257 sha256=0c00c65eaaf786d7624c5b1dab

In [105]:
#Source: https://gist.github.com/proger/a7e820fbfa0181273fdbf2351901d0d8

def make_frames(wav):
    """Compute Kaldi-compatible MFCC feature frames for a waveform tensor.

    Thin wrapper around ``torchaudio.compliance.kaldi.mfcc`` with all of its default
    options. Those defaults include the sample frequency (presumably 16 kHz, matching
    LibriSpeech — TODO confirm before using other audio).
    """
    mfcc_features = torchaudio.compliance.kaldi.mfcc(wav)
    return mfcc_features


class LibriSpeech(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(mfcc_frames, transcript)`` pairs from LibriSpeech.

    Wraps ``torchaudio.datasets.LIBRISPEECH`` rooted at the current directory, with
    ``download=True`` so the requested split is fetched if it is not already present.
    """

    def __init__(self, url='dev-clean'):
        super().__init__()
        # Root '.' means the archive and extracted audio live in the working directory.
        self.librispeech = torchaudio.datasets.LIBRISPEECH('.', url=url, download=True)

    def __len__(self):
        """Number of utterances in the wrapped split."""
        return len(self.librispeech)

    def __getitem__(self, index):
        """Return the MFCC frames of utterance ``index`` together with its transcript."""
        # Items are (waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id);
        # only the waveform and transcript are used here.
        waveform, _, transcript, _, _, _ = self.librispeech[index]
        return make_frames(waveform), transcript

      
class Encoder(nn.Module):
    """Acoustic encoder: strided Conv1d subsampling followed by a 3-layer LSTM.

    Feature frames of width ``input_dim`` are projected to ``subsample_dim`` channels
    by a kernel-5, stride-4, padding-3 convolution (about 4x fewer time steps). The
    result goes through ReLU, then the LSTM, and the LSTM outputs are ReLU-activated.
    """

    def __init__(self, input_dim=13, subsample_dim=128, hidden_dim=1024):
        super().__init__()
        # ``subsample`` / ``lstm`` are state_dict key prefixes; keep the names stable so
        # pretrained checkpoints still load.
        self.subsample = nn.Conv1d(input_dim, subsample_dim, 5, stride=4, padding=3)
        self.lstm = nn.LSTM(subsample_dim, hidden_dim, batch_first=True, num_layers=3, dropout=0.2)

    def subsampled_lengths(self, input_lengths):
        """Map input sequence lengths to the lengths after the subsampling convolution.

        Applies the standard conv output-size formula
        ``floor((L + 2 * padding - kernel) / stride + 1)``
        (https://github.com/vdumoulin/conv_arithmetic) and returns an int32 tensor.
        """
        conv = self.subsample
        (padding,) = conv.padding
        (kernel,) = conv.kernel_size
        (stride,) = conv.stride
        span = input_lengths + 2 * padding - kernel
        return torch.floor(span / stride + 1).int()

    def forward(self, inputs):
        """Encode feature frames into non-negative LSTM features.

        Assumes ``inputs`` is ``(time, input_dim)`` or ``(batch, time, input_dim)``
        (TODO confirm for other callers). ``.mT`` swaps the last two axes so the Conv1d
        sees channels-first data, then swaps them back for the batch-first LSTM.
        """
        subsampled = self.subsample(inputs.mT).mT.relu()
        recurrent_out, _state = self.lstm(subsampled)
        return recurrent_out.relu()

class Recognizer(nn.Module):
    """Per-frame symbol classifier: a linear layer followed by log-softmax.

    The default ``feat_dim`` (1024) matches the Encoder's default ``hidden_dim``. The
    default ``vocab_size`` (55+1 = 56) equals the number of entries in
    ``Vocabulary.rdictionary``, CTC blank included.
    """

    def __init__(self, feat_dim=1024, vocab_size=55+1):
        super().__init__()
        # ``classifier`` is a state_dict key prefix; keep the name for checkpoint compatibility.
        self.classifier = nn.Linear(feat_dim, vocab_size)

    def forward(self, features):
        """Return log-probabilities over the vocabulary along the last axis."""
        logits = self.classifier(features)
        return torch.log_softmax(logits, dim=-1)
    
class Vocabulary:
    """Phoneme vocabulary for CTC targets, using the ``g2p_en`` grapheme-to-phoneme model.

    Layout of the symbol table:
    - Index 0 is the CTC blank ``"ε"`` and index 1 is the word separator ``" "``.
    - The remaining entries are CMUdict phoneme symbols.
    - Vowels carry only stress markers 0 and 1, because secondary stress is folded into 0.
    """

    def __init__(self):
        self.g2p = G2p()

        # http://www.speech.cs.cmu.edu/cgi-bin/cmudict
        self.rdictionary = ["ε", # CTC blank
                            " ",
                            "AA0", "AA1", "AE0", "AE1", "AH0", "AH1", "AO0", "AO1", "AW0", "AW1", "AY0", "AY1",
                            "B", "CH", "D", "DH",
                            "EH0", "EH1", "ER0", "ER1", "EY0", "EY1",
                            "F", "G", "HH",
                            "IH0", "IH1", "IY0", "IY1",
                            "JH", "K", "L", "M", "N", "NG",
                            "OW0", "OW1", "OY0", "OY1",
                            "P", "R", "S", "SH", "T", "TH",
                            "UH0", "UH1", "UW0", "UW1",
                            "V", "W", "Y", "Z", "ZH"]

        # Inverse lookup table: symbol -> integer id.
        self.dictionary = dict(zip(self.rdictionary, range(len(self.rdictionary))))

    def __len__(self):
        """Vocabulary size, including the blank and the space symbol."""
        return len(self.rdictionary)

    def encode(self, text):
        """Convert ``text`` to a 1-D LongTensor of symbol ids.

        Apostrophes emitted by g2p are dropped, and secondary stress ("2") is mapped to
        unstressed ("0"). Any other symbol missing from the vocabulary (e.g. punctuation,
        if g2p emits it) raises ``KeyError``.
        """
        ids = []
        for symbol in self.g2p(text):
            if symbol == "'":
                continue
            ids.append(self.dictionary[symbol.replace('2', '0')])
        return torch.LongTensor(ids)

In [58]:
# LibriSpeech 'dev-clean' split (the class default), rooted at the working directory and downloaded if missing.
dataset = LibriSpeech()

In [13]:
! curl https://wilab.org.ua/lstm_p3_360+500.pt -o lstm_p3_360+500.pt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  247M  100  247M    0     0  3587k      0  0:01:10  0:01:10 --:--:-- 4655k 0:00:16 5734k 3658k      0  0:01:09  0:00:56  0:00:13 4959k


In [99]:
vocab = Vocabulary()
# These models are used only for inference, so put them in eval mode. nn.Module starts in
# training mode, where the Encoder's LSTM dropout (p=0.2) would randomize the outputs.
# Module.eval() returns the module itself, so it can be chained onto the constructor.
encoder = Encoder().eval()
recognizer = Recognizer().eval()

In [100]:
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code, and this
# checkpoint comes from an external URL. Consider torch.load(..., weights_only=True) if the
# installed torch supports it and the checkpoint holds only tensors/containers — TODO confirm.
# map_location='cpu' places every tensor on the CPU regardless of where it was saved from.
ckpt = torch.load('lstm_p3_360+500.pt', map_location='cpu')
# The checkpoint maps 'encoder' / 'recognizer' to the state_dicts of the respective modules.
encoder.load_state_dict(ckpt['encoder'])
recognizer.load_state_dict(ckpt['recognizer'])

<All keys matched successfully>

In [102]:
def split_array(array, separators):
    """Split ``array`` into maximal runs of consecutive non-separator elements.

    Separator elements are dropped and empty runs never appear.

    Returns:
        ``(runs, start_indexes)``, where ``runs`` is a list of lists and
        ``start_indexes[i]`` is the position in ``array`` of ``runs[i][0]``.
    """
    runs = []
    starts = []
    current = None  # the run being extended, or None right after a separator
    for position, item in enumerate(array):
        if item in separators:
            current = None
            continue
        if current is None:
            current = []
            runs.append(current)
            starts.append(position)
        current.append(item)
    return runs, starts
    
def remove_duplicates(array):
    """Collapse each run of equal consecutive elements into a single element.

    Non-adjacent repeats are kept, e.g. ``[1, 1, 2, 1] -> [1, 2, 1]``. Accepts any
    iterable and returns a new list.
    """
    collapsed = []
    for item in array:
        if not collapsed or not (item == collapsed[-1]):
            collapsed.append(item)
    return collapsed

def get_alignments(audio_frames, frames_per_second=25.0):
    """Greedy CTC decoding of one utterance into timed phoneme segments.

    Runs the module-level ``encoder`` and ``recognizer`` on ``audio_frames`` and takes
    the arg-max symbol per output frame. Consecutive frames that are neither the CTC
    blank (id 0) nor the space symbol (id 1) are grouped into segments.

    Args:
        audio_frames: feature frames for a single utterance, as produced by
            ``make_frames``. Presumably ``(time, 13)``; batched input is not supported.
        frames_per_second: rate of the encoder's output frames. The default 25.0
            presumably reflects Kaldi's default 10 ms frame shift (100 frames/s)
            divided by the encoder's stride-4 subsampling.

    Returns:
        A list of dicts, one per segment. ``start_time`` / ``end_time`` are in seconds,
        and ``phones`` is the space-joined segment symbols with consecutive repeats
        collapsed.
    """
    # Decode in eval mode. nn.Module starts in training mode, where the Encoder's LSTM
    # dropout would randomly perturb the outputs. Restore the previous modes afterwards
    # so callers that train these modules are unaffected.
    encoder_was_training = encoder.training
    recognizer_was_training = recognizer.training
    encoder.eval()
    recognizer.eval()
    try:
        # Decoding needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            features = encoder(audio_frames)
            log_probs = recognizer(features)
    finally:
        encoder.train(encoder_was_training)
        recognizer.train(recognizer_was_training)

    # Best symbol per output frame; the vocabulary is the last axis.
    utterance_symbol_indexes = torch.argmax(log_probs, dim=-1)

    splits, start_indexes = split_array(utterance_symbol_indexes.cpu().numpy(), {0, 1})

    alignments = []
    for split, start_index in zip(splits, start_indexes):
        alignments.append({
            'start_time': start_index / frames_per_second,
            'end_time': (start_index + len(split)) / frames_per_second,
            'phones': ' '.join(vocab.rdictionary[x] for x in remove_duplicates(split)),
        })

    return alignments

In [104]:
# Decode the first dev-clean utterance. dataset[i] is (mfcc_frames, transcript), so [0] selects the features.
get_alignments(dataset[0][0])

[{'start_time': 0.8, 'end_time': 1.0, 'phones': 'M IH1 S T ER0'},
 {'start_time': 1.08, 'end_time': 1.32, 'phones': 'K R IH1 L T'},
 {'start_time': 1.36, 'end_time': 1.4, 'phones': 'ER0'},
 {'start_time': 1.52, 'end_time': 1.6, 'phones': 'IH1 Z'},
 {'start_time': 1.64, 'end_time': 1.72, 'phones': 'DH AH0'},
 {'start_time': 1.84, 'end_time': 1.88, 'phones': 'AH0'},
 {'start_time': 1.92, 'end_time': 2.0, 'phones': 'P'},
 {'start_time': 2.04, 'end_time': 2.12, 'phones': 'AA1 S'},
 {'start_time': 2.16, 'end_time': 2.32, 'phones': 'AH0 L'},
 {'start_time': 2.4, 'end_time': 2.48, 'phones': 'AH1 V'},
 {'start_time': 2.52, 'end_time': 2.6, 'phones': 'DH AH0'},
 {'start_time': 2.64, 'end_time': 2.84, 'phones': 'M IH1 D AH0 L'},
 {'start_time': 2.92, 'end_time': 3.16, 'phones': 'K L AE1 S'},
 {'start_time': 3.2, 'end_time': 3.24, 'phones': 'AH0'},
 {'start_time': 3.32, 'end_time': 3.36, 'phones': 'Z'},
 {'start_time': 3.52, 'end_time': 3.64, 'phones': 'AH0 N D'},
 {'start_time': 3.72, 'end_time'