In [1]:
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from g2p_en import G2p
import re

from basicOperations.manifoldOperations import matrixDistance, frechetMean
import torch.nn.utils as utils

from rnn import euclideanRnn
import math

import pickle
import Levenshtein
import os

import matplotlib.pyplot as plt

In [None]:
"""
Reproduces Table 1, Figure 2, and Figure 3.
"""

In [None]:
"""
Train LARGE-VOCAB EMG-to-phoneme conversion.

For description of the data, please see largeVocabDataVisualization.ipynb

Unlike the SMALL-VOCAB data, there are no word-level timestamps within a sentence.

Each sentence is decoded in full, and the model is trained with CTC loss. The pipeline resembles standard speech-to-text (ASR) techniques.
"""

"""Install the Levenshtein package (https://pypi.org/project/Levenshtein/) for edit-distance computation."""

In [2]:
"""
Open Data.
"""

# NOTE: pickle.load can execute arbitrary code; only load these files from a trusted source.
def _loadPickle(path):
    """Return the object stored in the pickle file at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


DATA = _loadPickle("DATA/dataLargeVocab.pkl")
LABELS = _loadPickle("DATA/labelsLargeVocab.pkl")

In [3]:
"""
English phoneme definitions.
"""

# ARPAbet phone inventory; list position == class index.
# '<sp>' (index 39) marks word boundaries and 'SIL' (index 40) is used as the CTC blank.
PHONE_DEF = [
    'AO', 'OY', 'DH', 'ZH', 'SH', 'CH', 'UH', 'NG', 'IY', 'AA',
    'W', 'S', 'IH', 'K', 'EY', 'JH', 'Y', 'N', 'OW', 'M',
    'P', 'T', 'B', 'AY', 'UW', 'R', 'G', 'EH', 'Z', 'TH',
    'AW', 'HH', 'AH', 'AE', 'L', 'ER', 'F', 'V', 'D',
    '<sp>', 'SIL',
]

def phoneToId(p):
    """Return the class index of phone symbol ``p`` (raises ValueError for unknown symbols)."""
    position = PHONE_DEF.index(p)
    return position

# Grapheme-to-phoneme converter (g2p_en); used to phonemize both the sentences and the word lexicon.
g2p = G2p()

In [4]:
"""
Phonemize the sentences.
"""

def _toPhoneTokens(sentence):
    """Phonemize ``sentence`` with g2p, dropping stress digits and punctuation.

    Spaces between words become the '<sp>' token; every other kept token is an
    upper-case ARPAbet symbol.
    """
    tokens = []
    for rawToken in g2p(sentence):
        stripped = re.sub(r'[0-9]', '', rawToken)
        if stripped == " ":
            tokens.append("<sp>")
        elif re.match(r'[A-Z]+', stripped):
            tokens.append(stripped)
    return tokens


phonemizedSentences = [_toPhoneTokens(sentence) for sentence in LABELS]

In [5]:
"""
Convert phone-to-indices using look-up dictionary PHONE_DEF.
"""

# Replace every phone symbol by its class index (position in PHONE_DEF).
phoneIndexedSentences = [
    [phoneToId(phone) for phone in sentencePhones]
    for sentencePhones in phonemizedSentences
]

In [6]:
"""
Pad the phone transcribed sentences to a common length (to be used with CTC loss).
"""

# Pack the variable-length phone-index sequences into one (numSentences, width) float
# matrix for nn.CTCLoss. -1 marks padding; labelLengths[i] holds the true length of
# row i, and CTCLoss only reads that many target entries per row.
# The width was previously hard-coded to 76, which raised an error for any longer
# transcription; it now grows to the longest sequence (and stays 76 otherwise).
maxLabelLength = max([76] + [len(ids) for ids in phoneIndexedSentences])
phonemizedLabels = np.full((len(phoneIndexedSentences), maxLabelLength), -1.0)
for row, ids in enumerate(phoneIndexedSentences):
    phonemizedLabels[row, :len(ids)] = ids

labelLengths = np.array([len(ids) for ids in phoneIndexedSentences], dtype=float)

In [7]:
"""
z-normalize the data along the time dimension.
"""

# z-normalize every recording per row along its last (time) axis.
normDATA = []
for i in range(len(DATA)):
    recording = DATA[i]
    Mean = np.mean(recording, axis = -1)
    Std = np.std(recording, axis = -1)
    # A constant (flat) channel has std == 0; dividing by it would produce NaN/inf
    # that propagates into every covariance window of this recording. Such rows
    # are mapped to all zeros instead.
    Std = np.where(Std == 0, 1.0, Std)
    normDATA.append((recording - Mean[..., np.newaxis]) / Std[..., np.newaxis])

In [8]:
"""
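Slice each recording into overlapping 50 ms windows (2 * windowSize = 250 samples) with a 20 ms step (100 samples), and compute a regularized covariance matrix per window.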
"""

# For each z-normalized recording (rows presumably channels, axis 1 is time),
# compute one regularized second-moment matrix per overlapping window.
slicedMatrices = []
for j in range(len(normDATA)):
    collect = []
    stepSize = 100 
    # windowSize is a half-width: each window covers samples [start, start + 2 * windowSize).
    windowSize = 125
    dataLength = normDATA[j].shape[1]
    # NOTE(review): this count uses windowSize, but every slice is 2 * windowSize long,
    # so the last windows can extend past dataLength. NumPy slicing silently truncates
    # them while the 1/(2 * windowSize) scaling below still assumes a full window.
    # Presumably the saved Frechet mean / checkpoint were built with this exact slicing,
    # so changing it would alter the reported results.
    numIters = (dataLength - windowSize) // stepSize + 1
       
    for i in range(numIters):
        where = i * stepSize + windowSize
        start = where - windowSize
        End = where + windowSize
        # Uncentered second-moment matrix of the window (not re-centered per window),
        # scaled by the nominal window length.
        temp = 1/(2 * windowSize) * (normDATA[j][:, start:End] @ normDATA[j][:, start:End].T)
        # Shrinkage toward trace(temp) * I (trace, not trace / 31): the result is strictly
        # positive definite whenever trace(temp) > 0. The hard-coded 31 must equal the
        # number of rows of normDATA[j].
        collect.append(0.9 * temp + 0.1 * np.trace(temp) * np.eye(31))
    slicedMatrices.append(collect)

In [9]:
""" DIAG (True or False): if True, rotate each SPD matrix into the eigenbasis of the Frechet mean (approximate diagonalization); if False, use the raw SPD matrices."""
# True: project each SPD matrix onto the Frechet-mean eigenvectors; False: use the raw SPD matrices.
DIAG = True

In [10]:
"""
Approximately diagonalize the matrices using the eigenvectors of the precomputed Frechet mean.
"""

# Project every window matrix onto the eigenbasis of the precomputed Frechet mean
# (when DIAG), then pad all recordings to a common number of windows with identity
# matrices so they share one array shape. inputLengths keeps the real window counts.
MEAN = np.load("DATA/ckptsLargeVocab/frechetMeanLargeVocab.npy")
# NOTE: np.linalg.eig (not eigh) is kept on purpose. eigh may order/sign the
# eigenvectors differently, which would change the model inputs relative to the
# saved checkpoint.
eigenvalues, eigenvectors = np.linalg.eig(MEAN)

identityMatrix = np.eye(31)
# 409 was the original fixed padding length. It is grown when any recording has
# more windows, so the assignment below cannot raise IndexError.
maxWindows = max([409] + [len(windows) for windows in slicedMatrices])
afterMatrices = np.tile(identityMatrix, (len(slicedMatrices), maxWindows, 1, 1))
inputLengths = np.zeros((len(slicedMatrices)))
for i, windows in enumerate(slicedMatrices):
    for j, window in enumerate(windows):
        afterMatrices[i, j] = eigenvectors.T @ window @ eigenvectors if DIAG else window
    inputLengths[i] = len(windows)

In [11]:
class BaseDataset(Dataset):
    """Map-style dataset over pre-padded arrays.

    Each item is ``(features as float32, label row, inputLength as int, targetLength as int)``.
    Note the constructor order: ``inputLength`` comes before ``targetLength``.
    """

    def __init__(self, data, labels, inputLength, targetLength):
        self.data, self.labels = data, labels
        self.inputLength, self.targetLength = inputLength, targetLength

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return (
            self.data[index].astype('float32'),
            self.labels[index],
            int(self.inputLength[index]),
            int(self.targetLength[index]),
        )

In [12]:
"""
Train-validation-test split.
"""

# Contiguous split by sample index: [0, 8000) train, [8000, 9000) validation, [9000, end) test.
def _takeSplit(window):
    """Return (features, labels, labelLengths, inputLengths) for the samples selected by ``window``."""
    return (afterMatrices[window], phonemizedLabels[window],
            labelLengths[window], inputLengths[window])


trainFeatures, trainLabels, trainLabelLengths, trainInputLengths = _takeSplit(slice(None, 8000))
valFeatures, valLabels, valLabelLengths, valInputLengths = _takeSplit(slice(8000, 9000))
testFeatures, testLabels, testLabelLengths, testInputLengths = _takeSplit(slice(9000, None))

In [13]:
# Wrap each split in a Dataset and a batched DataLoader; only the training loader shuffles.
trainDataset = BaseDataset(trainFeatures, trainLabels, trainInputLengths, trainLabelLengths)
valDataset = BaseDataset(valFeatures, valLabels, valInputLengths, valLabelLengths)
testDataset = BaseDataset(testFeatures, testLabels, testInputLengths, testLabelLengths)

trainDataloader, valDataloader, testDataloader = (
    DataLoader(split, batch_size = 32, shuffle = isTrain)
    for split, isTrain in ((trainDataset, True), (valDataset, False), (testDataset, False))
)

In [16]:
"""
To replicate the PER (phoneme error rate) for various model sizes and layers, change the arguments here:
euclideanRnn.RnnNet(41, 25, device, numLayers = 3).to(device)   # 41 output classes; hidden dimension (modelHiddenDimension) = 25
"""

# Use the first GPU when available; fall back to CPU so the evaluation still runs without CUDA.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

numberEpochs = 100  # not used by the cells in this notebook

# One output class per entry of PHONE_DEF (41); hidden dimension 25, 3 layers.
model = euclideanRnn.RnnNet(len(PHONE_DEF), 25, device, numLayers = 3).to(device)
numParams = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(numParams)
# The CTC blank is the 'SIL' class (index 40), the same symbol the beam search treats as blank.
lossFunction = nn.CTCLoss(blank = PHONE_DEF.index("SIL"), zero_infinity = True)
rnnOptimizer = optim.Adam(model.parameters(), lr = 0.001, weight_decay = 1e-3)

6348591


In [17]:
def testOperation(model, device, testLoader, Loss):
    model.eval()
    totalLoss = 0
    Outputs = []
    with torch.no_grad():
        for inputs, targets, inputLengths, targetLengths in testLoader:
            inputs, targets = inputs.to(device), targets.to(device)
            inputLengths, targetLengths = inputLengths.to(device), targetLengths.to(device)
            
            outputs = model(inputs, inputLengths.cpu()) 

            loss = Loss(outputs, targets, inputLengths, targetLengths)
            totalLoss += loss.item()
            Outputs.append(outputs.transpose(0, 1))

    return Outputs, totalLoss / len(testLoader)

In [18]:
# map_location remaps tensors saved on a GPU onto the currently selected device,
# so the checkpoint also loads on a CPU-only machine.
modelWeight = torch.load("DATA/ckptsLargeVocab/ckptLargeVocab.pt", map_location = device, weights_only = True)
model.load_state_dict(modelWeight)
output, testLoss = testOperation(model, device, testDataloader, lossFunction)

print("TEST LOSS: ", testLoss)

TEST LOSS:  1.8337221510948674


In [19]:
# Flatten the per-batch outputs into one list with one entry per test utterance
# (each batch tensor is iterated along its first dimension).
outs = [utterance for batch in output for utterance in batch]

In [20]:
import math
import numpy as np
import kenlm

In [21]:
def ctcBeamSearchPhoneLM(
    logProbs, length, lm, phoneDef, beamSize = 20, lmWeight = 1.2,
    insertionBonus = 0.0, topk = None, allowDoubles = True, blankPhone = "SIL"
):
    """CTC prefix beam search over phone classes with optional phone-level KenLM fusion.

    Args:
        logProbs: 2-D array-like of shape (frames, len(phoneDef)). Scores are summed and
            merged with ``np.logaddexp``, so they are presumably natural-log probabilities
            (e.g. log-softmax output) — TODO confirm against the model.
        length: number of leading frames to decode; ``None`` uses all frames. Values larger
            than the number of frames are clipped.
        lm: a ``kenlm.Model`` or ``None`` (acoustic scores + insertion bonus only).
        phoneDef: list of phone symbols; list position == class index.
        beamSize: number of prefixes kept after each frame.
        lmWeight: multiplier on each LM log-probability increment.
        insertionBonus: added every time a symbol is appended to a prefix.
        topk: if set and smaller than the vocabulary, only the ``topk`` best classes per frame
            are expanded; the blank is always forced into that candidate set.
        allowDoubles: if True, a repeated symbol separated by a blank is emitted twice;
            if False, consecutive identical symbols never appear in the output.
        blankPhone: symbol in ``phoneDef`` used as the CTC blank; it is never scored by the LM.

    Returns:
        Tuple of class indices (not phone strings) of the highest-scoring prefix.
        No end-of-sentence LM score is added before the final selection.
    """
    lp = np.asarray(logProbs)
    Ttotal, V = lp.shape
    T = Ttotal if length is None else int(min(length, Ttotal))
    assert len(phoneDef) == V
    # KenLM scores are log10; multiply by ln(10) to convert to natural log.
    LN10 = math.log(10.0)

    blank = phoneDef.index(blankPhone)
    # class index -> LM token; the blank maps to None so it is never sent to the LM.
    idx2tok = [None if ph == blankPhone else ph for ph in phoneDef]

    # Initial LM state primed with the begin-of-sentence context (None without an LM).
    def lmBos():
        if lm is None: return None
        s = kenlm.State(); lm.BeginSentenceWrite(s); return s

    # Advance the LM by one token; returns (newState, natural-log score increment).
    # Tokens missing from the LM fall back to "<unk>" when available; otherwise they
    # contribute 0 and leave the state unchanged.
    def lmAdv(st, tok):
        if lm is None or tok is None: return st, 0.0
        if tok not in lm and "<unk>" in lm: tok = "<unk>"
        if tok not in lm: return st, 0.0
        ns = kenlm.State()
        inc = lm.BaseScore(st, tok, ns) 
        return ns, inc * LN10

    # prefix (tuple of class indices) -> (log P ending in blank, log P ending in non-blank, LM state)
    beams = {(): (0.0, -np.inf, lmBos())} 

    # Merge probability mass for `seq` into `store` via log-sum-exp; the first non-None
    # LM state seen for a prefix is kept.
    def add(store, seq, addPb, addPnb, st):
        if seq in store:
            pb, pnb, cur = store[seq]
            if addPb  != -np.inf: pb  = np.logaddexp(pb,  addPb)
            if addPnb != -np.inf: pnb = np.logaddexp(pnb, addPnb)
            if cur is None and st is not None: cur = st
            store[seq] = (pb, pnb, cur)
        else:
            store[seq] = (addPb, addPnb, st)

    for t in range(T):
        row = lp[t]
        new = {}

        # Optional per-frame candidate pruning; the blank is swapped in for the worst candidate.
        if topk is not None and topk < V:
            cand = np.argpartition(row, -topk)[-topk:]
            if blank not in cand:
                worst = cand[np.argmin(row[cand])]
                cand[cand == worst] = blank
        else:
            cand = range(V)

        for seq, (pb, pnb, st) in beams.items():
            # Emitting a blank keeps the prefix unchanged; it now ends in blank.
            add(new, seq, np.logaddexp(pb, pnb) + row[blank], -np.inf, st)

            last = seq[-1] if seq else None

            for c in cand:
                if c == blank: continue
                pC = row[c]

                if c == last:
                    # Repeat without an intervening blank collapses into the same prefix.
                    add(new, seq, -np.inf, pnb + pC, st)
                    if allowDoubles:
                        # Repeat after a blank emits the symbol a second time.
                        tok = idx2tok[c]
                        ns, inc = lmAdv(st, tok)
                        add(new, seq + (c,), -np.inf, pb + pC + lmWeight*inc + insertionBonus, ns)
                else:
                    # New symbol: extend the prefix and add the weighted LM increment + bonus.
                    tok = idx2tok[c]
                    ns, inc = lmAdv(st, tok)
                    add(new, seq + (c,), -np.inf, np.logaddexp(pb, pnb) + pC + lmWeight*inc + insertionBonus, ns)

        # Keep the beamSize best prefixes by total score; sorted() is stable, so ties keep insertion order.
        if len(new) > beamSize:
            items = sorted(new.items(),
                           key=lambda kv: np.logaddexp(kv[1][0], kv[1][1]),
                           reverse=True)[:beamSize]
            beams = dict(items)
        else:
            beams = new

    bestSeq = max(beams.items(), key=lambda kv: np.logaddexp(kv[1][0], kv[1][1]))[0]
    return bestSeq


In [None]:
"""6gram.arpa is a phoneme-level LM built from the WikiText-103 corpus. The entire corpus was rewritten at the phone level,
with <sp> marking the space between words.

bash: /mnt/dataDrive/kenLM/kenlm/build/bin/lmplz -o 6 --discount_fallback \
  < wikitext-103.phone.txt > 6gram.arpa

Use this LM with the beam search. This setup is very similar to https://github.com/facebookresearch/emg2qwerty.


Then, we use a different LM for phone-to-word mapping.
  """

In [23]:
# Phone-level 6-gram KenLM model used by ctcBeamSearchPhoneLM.
# NOTE: `lm` is rebound to the word-level 3-gram model in a later cell.
lm = kenlm.Model("DATA/ckptsLargeVocab/6gram.arpa") 

Loading the LM will be faster if you build a binary file.
Reading /mnt/dataDrive/emgFullCorpora/toUpload/DATA/ckptsLargeVocab/6gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


In [None]:
LEVS = []  # collects the mean edit distance of each decoding run (appended in a later cell)
decodedOut = []
# Decode every test utterance. The count previously hard-coded as 1970 now follows
# len(outs), so it always matches the size of the test split.
for sampleIdx in range(len(outs)):
    decodedSymbols = ctcBeamSearchPhoneLM(
        logProbs = outs[sampleIdx].cpu().numpy(),
        length = int(testInputLengths[sampleIdx]),
        lm = lm,
        phoneDef = PHONE_DEF,
        beamSize = 5,
        lmWeight = 0.4,
        insertionBonus = 1.2,
        topk = None,
        allowDoubles = True,
        blankPhone = "SIL",
    )
    # Map class indices back to phone symbols. A distinct loop name avoids shadowing
    # the outer loop variable (the original reused `i` for both loops).
    decodedOut.append([PHONE_DEF[symbolIdx] for symbolIdx in decodedSymbols])

In [25]:
def findClosestTranscription(decodedTranscript, phoneticTranscription):
    """Return the Levenshtein (edit) distance between two phone sequences.

    Despite the name, no search is performed: the two sequences are compared directly.
    ``Levenshtein`` is looked up in module globals at call time.
    """
    return Levenshtein.distance(decodedTranscript, phoneticTranscription)

In [26]:
# Ground-truth phone sequences for the test split, which starts at sample 9000.
references = [phonemizedSentences[9000 + k] for k in range(len(decodedOut))]
phoneLENGTHS = [len(reference) for reference in references]
levs = [
    findClosestTranscription(hypothesis, reference)
    for hypothesis, reference in zip(decodedOut, references)
]
LEVS.append(np.mean(levs))

In [27]:
# Average reference length in phones (including '<sp>' word-boundary tokens).
print("Mean length of sentences: ", np.mean(phoneLENGTHS))
# NOTE: this is the mean Levenshtein distance per sentence (an edit count), not a rate.
print("Mean phoneme error rate (insertion errors + deletion errors + substitution errors): ", np.mean(levs))
# Ratio of means = total edits / total reference phones; printed as a fraction (0.45 means 45%).
print("Percent phoneme error: ", np.mean(levs)/np.mean(phoneLENGTHS))

Mean length of sentences:  24.543654822335025
Mean phoneme error rate (insertion errors + deletion errors + substitution errors):  11.10253807106599
Percent phoneme error:  0.45235879299290604


In [28]:
# Per-sentence normalized edit distance; argsort lists the best-decoded sentences first.
perSentenceErrorRate = np.array(levs) / np.array(phoneLENGTHS)
indices = np.argsort(perSentenceErrorRate)
print(indices[:100])

[ 939 1800   90 1771  540   89  835  155 1315  989 1817  881 1214 1901
 1920 1212 1538  988 1374  170  536 1550  785  790  797 1222 1467  779
  188 1128 1035 1382 1130 1504  839   42 1962 1027 1720  738 1908  602
  693  841 1381  741 1119 1798 1564 1924 1447  197  371  173  815  289
 1546 1231  778  740  894  674 1265  154  892  810 1026 1056 1385 1126
 1116  685 1686 1473 1037 1807  354 1135  698  713 1961  706  476   78
    8 1892 1244  620 1080  428 1185 1620 1619 1751  762  198  623 1121
 1047  435]


In [29]:
"""
Visualize decoded sentences.
"""

# Inspect one test utterance (index into the test split) next to its ground truth.
which = 939
hypothesis = decodedOut[which]
reference = phonemizedSentences[9000 + which]
print("Decoded phoneme sequence: ", hypothesis)
print("Ground truth phoneme sequence: ", reference)
print("Ground truth label: ", LABELS[9000 + which])
print(" ")
print("Levenshtein distance between decoded and ground truth sequence: ", Levenshtein.distance(hypothesis, reference))
print("Length of ground truth sequence: ", len(reference))

Decoded phoneme sequence:  ['DH', 'AH', '<sp>', 'S', 'EY', 'M', '<sp>', 'S', 'IH', 'T', 'IY', '<sp>', 'AO', 'R', '<sp>', 'K', 'AW', 'N', 'T', 'IY']
Ground truth phoneme sequence:  ['DH', 'AH', '<sp>', 'S', 'EY', 'M', '<sp>', 'S', 'IH', 'T', 'IY', '<sp>', 'AO', 'R', '<sp>', 'K', 'AW', 'N', 'T', 'IY']
Ground truth label:  the same city or county
 
Levenshtein distance between decoded and ground truth sequence:  0
Length of ground truth sequence:  20


In [None]:
"""
Do word modeling with a very simple 3-gram model.
The model was trained on LibriSpeech-100 sentences with KenLM using the command: lmplz -o 3 < train.txt > 3gram.arpa

1) We mainly rely on matching phoneme segments to word pronunciations in a lexicon dictionary containing all unique words.
2) The KenLM language model is only used to disambiguate and choose coherent word sequences.
"""

In [31]:
import kenlm
from rapidfuzz.distance import Levenshtein
from itertools import product
from heapq import heappush, heappop
from jiwer import wer
from typing import List, Dict, Tuple

In [32]:
# Word-level 3-gram KenLM model (LibriSpeech text); replaces the phone-level `lm` loaded earlier.
lm = kenlm.Model("DATA/ckptsLargeVocab/3gramLibri.arpa")

Loading the LM will be faster if you build a binary file.
Reading /mnt/dataDrive/emgFullCorpora/toUpload/DATA/ckptsLargeVocab/3gramLibri.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


In [33]:
with open("DATA/ckptsLargeVocab/libri100.txt", "r") as f:
    lines = list(map(str.strip, f))

# Vocabulary = every lower-cased \w+ token that appears anywhere in the corpus.
uniqueWords = {
    token
    for sentence in lines
    for token in re.findall(r"\b\w+\b", sentence.lower())
}

In [34]:
# Size of the word vocabulary extracted from libri100.txt.
print(len(uniqueWords))

34541


In [35]:
# Number of lines read from libri100.txt.
print(len(lines))

37663


In [36]:
def _wordToPhones(word):
    """Phonemize one lexicon word with g2p, dropping stress digits and punctuation tokens."""
    kept = []
    for token in g2p(word):
        token = re.sub(r'[0-9]', '', token)
        if token == " " or re.match(r'[A-Z]+', token):
            kept.append(token)
    return kept


# Pronunciation lexicon: word -> phone list, inserted in alphabetical order of the words.
phonemizedWords = {word: _wordToPhones(word) for word in sorted(uniqueWords)}

In [38]:
def fuzzyMatch(
    phonemeSegment: List[str],
    lexicon: Dict[str, List[str]],
    maxDist: int = 8,
    topK: int = 20
) -> List[Tuple[str, float]]:
    """Return up to ``topK`` lexicon words whose pronunciation is close to ``phonemeSegment``.

    Args:
        phonemeSegment: phone symbols of one word-like segment.
        lexicon: word -> list of phone symbols.
        maxDist: maximum raw Levenshtein distance (in phones) for a word to be kept.
        topK: maximum number of candidates returned.

    Returns:
        ``(word, normDist)`` pairs sorted by ascending ``normDist``. ``normDist`` is the edit
        distance divided by the longer of the two sequences, or 0.0 when both are empty.
        Ties keep lexicon iteration order. ``Levenshtein`` is looked up in module globals
        at call time.
    """
    matches = []
    for word, phonemes in lexicon.items():
        dist = Levenshtein.distance(phonemeSegment, phonemes)
        if dist > maxDist:
            continue
        longest = max(len(phonemeSegment), len(phonemes))
        # Two empty sequences are an exact match; avoid the ZeroDivisionError.
        normDist = dist / longest if longest else 0.0
        matches.append((word, normDist))
    return sorted(matches, key=lambda x: x[1])[:topK]

def _keepSegment(buffer: List[str], minLen: int, singlePhoneWords) -> bool:
    """True if a segment is long enough, or is a whitelisted single-phone word, to be decoded."""
    return len(buffer) >= minLen or tuple(buffer) in singlePhoneWords


def _segmentPhonemes(phonemes: List[str], blankToken: str, minLen: int, singlePhoneWords) -> List[List[str]]:
    """Split ``phonemes`` on ``blankToken`` and drop segments rejected by ``_keepSegment``."""
    segments = []
    buffer = []
    for p in phonemes:
        if p == blankToken:
            if _keepSegment(buffer, minLen, singlePhoneWords):
                segments.append(buffer)
            buffer = []
        else:
            buffer.append(p)
    # The last segment is not followed by a blank token.
    if _keepSegment(buffer, minLen, singlePhoneWords):
        segments.append(buffer)
    return segments


def phoneme2wordsHardSegmentation(
    phonemes: List[str],
    lexicon: Dict[str, List[str]],
    lm,
    blankToken: str = '<sp>',
    lambdaLM: float = 0.17,
    lambdaDist: float = 0.83,
    minLen: int = 2,
    singlePhoneWords: Tuple[Tuple[str, ...], ...] = (("AH",), ("AY",)),
) -> Tuple[List[str], float]:
    """Convert a phone sequence into words using hard segmentation at ``blankToken``.

    Each segment is matched against ``lexicon`` with ``fuzzyMatch``. The candidate maximizing
    ``lambdaLM * lmGain - lambdaDist * normDist`` is chosen greedily, left to right.
    ``lmGain`` is the KenLM score increase from appending the candidate to the words
    decoded so far.

    Args:
        phonemes: decoded phone symbols, with ``blankToken`` between words.
        lexicon: word -> list of phone symbols.
        lm: KenLM model exposing ``score(sentence, bos=..., eos=...)``.
        blankToken: word-boundary symbol.
        lambdaLM: weight of the LM gain.
        lambdaDist: weight of the normalized pronunciation distance.
        minLen: segments shorter than this are dropped unless whitelisted.
        singlePhoneWords: phone sequences kept even when shorter than ``minLen``
            (default: 'AH' and 'AY', e.g. "a" and "I").

    Returns:
        ``(decodedWords, totalScore)``. Segments with no lexicon match decode to "<UNK>".
        ``totalScore = lambdaLM * fullSentenceLMScore - lambdaDist * sum(best normDist)``.
    """
    allowedSingles = {tuple(word) for word in singlePhoneWords}
    segments = _segmentPhonemes(phonemes, blankToken, minLen, allowedSingles)

    decodedWords: List[str] = []
    normDistTotal = 0.0
    for segment in segments:
        matches = fuzzyMatch(segment, lexicon, maxDist = len(segment) + 1)
        if not matches:
            decodedWords.append("<UNK>")
            continue

        # Score of the prefix decoded so far; subtracting it gives each candidate's
        # conditional LM score (KenLM scores are log10).
        if decodedWords:
            prefix = " ".join(decodedWords)
            prefixScore = lm.score(prefix, bos=True, eos=False)
        else:
            prefix = ""
            prefixScore = 0.0

        scored = []
        for word, normDist in matches:
            candSentence = prefix + " " + word if prefix else word
            lmGain = lm.score(candSentence, bos=True, eos=False) - prefixScore
            score = lambdaLM * lmGain - lambdaDist * normDist
            scored.append((word, score, normDist))

        # max() keeps the first candidate on score ties; matches are sorted by ascending distance.
        bestWord, _, bestNormDist = max(scored, key=lambda x: x[1])
        decodedWords.append(bestWord)
        normDistTotal += bestNormDist

    finalLMScore = lm.score(" ".join(decodedWords), bos=True, eos=True) if decodedWords else 0.0
    totalScore = lambdaLM * finalLMScore - lambdaDist * normDistTotal
    return decodedWords, totalScore

In [39]:
# Map each decoded phone sequence to a space-joined word string (the score is discarded).
RESULTS = [
    " ".join(phoneme2wordsHardSegmentation(phonemeSeq, phonemizedWords, lm)[0])
    for phonemeSeq in decodedOut
]

In [40]:
# Per-sentence word error rate against the original test labels (test split starts at 9000).
WER = [wer(LABELS[9000 + k], hypothesis) for k, hypothesis in enumerate(RESULTS)]

In [41]:
# Macro-average: mean of per-sentence WERs (every sentence weighted equally),
# not a corpus-level WER computed over all words at once.
print("Mean word error rate: ", np.mean(WER))

Mean word error rate:  0.7810801479963916


In [43]:
"""
Visualize WERs.
"""
# Inspect one test utterance at both the phone and the word level.
which = 939
hypothesis = decodedOut[which]
reference = phonemizedSentences[9000 + which]
print("Decoded phoneme sequence: ", hypothesis)
print("Ground truth phoneme sequence: ", reference)
print(" ")
print("Levenshtein distance between decoded and ground truth sequence: ", Levenshtein.distance(hypothesis, reference))
print("Length of ground truth sequence: ", len(reference))
print(" ")
print("Decoded sentence: ", RESULTS[which])
print("Original sentence: ", LABELS[9000 + which])

Decoded phoneme sequence:  ['DH', 'AH', '<sp>', 'S', 'EY', 'M', '<sp>', 'S', 'IH', 'T', 'IY', '<sp>', 'AO', 'R', '<sp>', 'K', 'AW', 'N', 'T', 'IY']
Ground truth phoneme sequence:  ['DH', 'AH', '<sp>', 'S', 'EY', 'M', '<sp>', 'S', 'IH', 'T', 'IY', '<sp>', 'AO', 'R', '<sp>', 'K', 'AW', 'N', 'T', 'IY']
 
Levenshtein distance between decoded and ground truth sequence:  0
Length of ground truth sequence:  20
 
Decoded sentence:  the same city or county
Original sentence:  the same city or county
