In [1]:
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from g2p_en import G2p
import re

from basicOperations.manifoldOperations import matrixDistance, frechetMean
import torch.nn.utils as utils

from rnn import euclideanRnn
import math

import pickle
import Levenshtein
import os

import matplotlib.pyplot as plt

In [None]:
"""
Proof for table 1, figure 2, and figure 3.
"""

In [None]:
"""
Train LARGE-VOCAB EMG-to-phoneme conversion.

For description of the data, please see largeVocabDataVisualization.ipynb

Unlike data SMALL-VOCAB, there are no timestamps between words within a sentence. 

Given a sentence, you decode it fully using CTC loss. The pipeline resembles standard speech-to-text (ASR) techniques.
"""

In [2]:
"""
Open Data.
"""

with open("DATA/dataLargeVocab.pkl", "rb") as file:
    DATA = pickle.load(file)

with open("DATA/labelsLargeVocab.pkl", "rb") as file:
    LABELS = pickle.load(file)

In [3]:
"""
English phoneme definitions.
"""

PHONE_DEF = ['AO', 'OY', 'DH', 'ZH', 'SH', 'CH', 'UH', 'NG', 'IY', 'AA', 'W', 'S', 'IH', 'K', 'EY', 'JH', 'Y', 'N', 'OW', 'M', 'P', 'T', 'B', 'AY', 'UW', 'R', 'G', 'EH', 'Z', 'TH', 'AW', 
             'HH', 'AH', 'AE', 'L', 'ER', 'F', 'V', 'D', ' ', 'SIL']

def phoneToId(p):
    return PHONE_DEF.index(p)

g2p = G2p()

In [4]:
"""
Phonemize the sentences.
"""

phonemizedSentences = []

for i in range(len(LABELS)):
    phones = []
    for p in g2p(LABELS[i]): 
        p = re.sub(r'[0-9]', '', p)   
        if re.match(r'[A-Z]+', p) or p == " ": 
            phones.append(p)
    phonemizedSentences.append(phones)

In [5]:
"""
Convert phone-to-indices using look-up dictionary PHONE_DEF.
"""

phoneIndexedSentences = []
for i in range(len(phonemizedSentences)):
    current = phonemizedSentences[i]
    phoneID = []
    for j in range(len(current)):
        phoneID.append(phoneToId(current[j]))
    phoneIndexedSentences.append(phoneID)

In [6]:
"""
Pad the phone transcribed sentences to a common length (to be used with CTC loss).
"""

phonemizedLabels = np.zeros((len(phoneIndexedSentences), 76)) - 1
for i in range(len(phoneIndexedSentences)):
    phonemizedLabels[i, 0:len(phoneIndexedSentences[i])] = phoneIndexedSentences[i]

labelLengths = np.zeros((len(phoneIndexedSentences)))
for i in range(len(phoneIndexedSentences)):
    labelLengths[i] = len(phoneIndexedSentences[i])

In [7]:
"""
z-normalize the data along the time dimension.
"""

normDATA = []
for i in range(len(DATA)):
    Mean = np.mean(DATA[i], axis = -1)
    Std = np.std(DATA[i], axis = -1)
    normDATA.append((DATA[i] - Mean[..., np.newaxis])/Std[..., np.newaxis])

In [8]:
"""
Slice the matrices into 50ms segments with a step size of 20ms.
"""

slicedMatrices = []
for j in range(len(normDATA)):
    collect = []
    stepSize = 100 
    windowSize = 125
    dataLength = normDATA[j].shape[1]
    numIters = (dataLength - windowSize) // stepSize + 1
       
    for i in range(numIters):
        where = i * stepSize + windowSize
        start = where - windowSize
        End = where + windowSize
        temp = 1/(2 * windowSize) * (normDATA[j][:, start:End] @ normDATA[j][:, start:End].T)
        collect.append(0.9 * temp + 0.1 * np.trace(temp) * np.eye(31))
    slicedMatrices.append(collect)

In [40]:
""" Diag = TRUE or FALSE. Raw SPD matrices or approximately diagonalized?"""
DIAG = True

In [41]:
"""
Approximately diagonalize the matrices using Frechet mean.
"""

MEAN = np.load("DATA/ckptsLargeVocab/frechetMeanLargeVocab.npy")
eigenvalues, eigenvectors = np.linalg.eig(MEAN)

identityMatrix = np.eye(31)
afterMatrices = np.tile(identityMatrix, (len(slicedMatrices), 409, 1, 1)) 
inputLengths = np.zeros((len(slicedMatrices)))
for i in range(len(slicedMatrices)):
    for j in range(len(slicedMatrices[i])):
        if DIAG:
            temp = eigenvectors.T @ slicedMatrices[i][j] @ eigenvectors
        else: 
            temp = slicedMatrices[i][j]
        afterMatrices[i, j] = temp
    inputLengths[i] = len(slicedMatrices[i])

In [42]:
class BaseDataset(Dataset):
    def __init__(self, data, labels, inputLength, targetLength):
        self.data = data 
        self.labels = labels
        self.targetLength = targetLength
        self.inputLength = inputLength

    def __getitem__(self, index):
        inputSeq = self.data[index].astype('float32')  
        targetSeq = self.labels[index]
        inputLength = int(self.inputLength[index])
        targetLength = int(self.targetLength[index])
        return inputSeq, targetSeq, inputLength, targetLength

    def __len__(self):
        return len(self.data)

In [43]:
"""
Train-validation-test split.
"""

trainFeatures = afterMatrices[:8000]
trainLabels = phonemizedLabels[:8000]
trainLabelLengths = labelLengths[:8000]
trainInputLengths = inputLengths[:8000]

valFeatures = afterMatrices[8000:9000]
valLabels = phonemizedLabels[8000:9000]
valLabelLengths = labelLengths[8000:9000]
valInputLengths = inputLengths[8000:9000]

testFeatures = afterMatrices[9000:]
testLabels = phonemizedLabels[9000:]
testLabelLengths = labelLengths[9000:]
testInputLengths = inputLengths[9000:]

In [44]:
trainDataset = BaseDataset(trainFeatures, trainLabels, trainInputLengths, trainLabelLengths)
valDataset = BaseDataset(valFeatures, valLabels, valInputLengths, valLabelLengths)
testDataset = BaseDataset(testFeatures, testLabels, testInputLengths, testLabelLengths)

trainDataloader = DataLoader(trainDataset, batch_size = 32, shuffle = True)
valDataloader = DataLoader(valDataset, batch_size = 32, shuffle = False)
testDataloader = DataLoader(testDataset, batch_size = 32, shuffle = False)

In [45]:
"""
To replicate the PER (phoneme error rate) for various model sizes and layers, change the variable here:
euclideanRnn.RnnNet(41, modelHiddenDimension = 25, device, numLayers = 3).to(device)
"""

dev = "cuda:0"
device = torch.device(dev)

numberEpochs = 100

model = euclideanRnn.RnnNet(41, 25, device, numLayers = 3).to(device)
numParams = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(numParams)
lossFunction = nn.CTCLoss(blank = 40, zero_infinity = True)
rnnOptimizer = optim.Adam(model.parameters(), lr = 0.001, weight_decay = 1e-3)

6348591


In [46]:
def testOperation(model, device, testLoader, Loss):
    model.eval()
    totalLoss = 0
    Outputs = []
    with torch.no_grad():
        for inputs, targets, inputLengths, targetLengths in testLoader:
            inputs, targets = inputs.to(device), targets.to(device)
            inputLengths, targetLengths = inputLengths.to(device), targetLengths.to(device)
            
            outputs = model(inputs, inputLengths.cpu()) 

            loss = Loss(outputs, targets, inputLengths, targetLengths)
            totalLoss += loss.item()
            Outputs.append(outputs.transpose(0, 1))

    return Outputs, totalLoss / len(testLoader)

In [47]:
"""
Simple beam-search algorithm.
"""

def beamSearch(ctcOutput, testInputLength, beamWidth = 5, blank = 40):
    T, V = ctcOutput.shape
    T = int(min(T, testInputLength))
    beams = {(): (0.0, -np.inf)}  

    for t in range(T):
        newBeams = {}

        for seq, (logProbBlank, logProbNonBlank) in beams.items():
            for symbol in range(V):
                prob = ctcOutput[t, symbol]  

                if symbol == blank:
                    newLogProbBlank = np.logaddexp(logProbBlank, logProbNonBlank) + prob
                    prevLogProbBlank, prevLogProbNonBlank = newBeams.get(seq, (-np.inf, -np.inf))
                    newBeams[seq] = (np.logaddexp(prevLogProbBlank, newLogProbBlank), prevLogProbNonBlank)
                else:
                    newSeq = seq + (symbol,) if not seq or seq[-1] != symbol else seq

                    newLogProbNonBlank = (
                        logProbBlank + prob if seq and seq[-1] == symbol
                        else np.logaddexp(logProbBlank, logProbNonBlank) + prob
                    )

                    if newSeq in newBeams:
                        prevLogProbBlank, prevLogProbNonBlank = newBeams[newSeq]
                        newBeams[newSeq] = (
                            prevLogProbBlank, 
                            np.logaddexp(prevLogProbNonBlank, newLogProbNonBlank)
                        )
                    else:
                        newBeams[newSeq] = (-np.inf, newLogProbNonBlank)

        
        beams = dict(sorted(newBeams.items(), key = lambda x: np.logaddexp(*x[1]), reverse = True)[:beamWidth])

    bestSeq = max(beams.items(), key = lambda x: np.logaddexp(*x[1]))[0]
    return bestSeq

def findClosestTranscription(decodedTranscript, phoneticTranscription):
    
    dist = Levenshtein.distance(decodedTranscript, phoneticTranscription)

    return dist

In [49]:
modelWeight = torch.load("DATA/ckptsLargeVocab/ckptLargeVocab.pt", weights_only = True)
model.load_state_dict(modelWeight)
output, testLoss = testOperation(model, device, testDataloader, lossFunction)

print("TEST LOSS: ", testLoss)

TEST LOSS:  1.8337221510948674


In [50]:
outs = []
for o in output:
    for oo in o:
        outs.append(oo)

In [51]:
LEVS = []
decodedOut = []
for i in range(1970):
    decodedSymbols = beamSearch(outs[i].cpu().numpy(), testInputLengths[i]) 
    phoneOut = []
    for i in range(len(decodedSymbols)):
        phoneOut.append(PHONE_DEF[decodedSymbols[i]])
    decodedOut.append(phoneOut)

levs = []
phoneLENGTHS = []
for i in range(len(decodedOut)):
    phoneLENGTHS.append(len(phonemizedSentences[9000 + i]))
    levs.append(findClosestTranscription(decodedOut[i], phonemizedSentences[9000 + i]))
LEVS.append(np.mean(levs))

In [52]:
print("Mean length of sentences: ", np.mean(phoneLENGTHS))
print("Mean phoneme error rate (insertion errors + deletion errors + substitution errors): ", np.mean(levs))
print("Percent phoneme error: ", np.mean(levs)/np.mean(phoneLENGTHS))

Mean length of sentences:  24.543654822335025
Mean phoneme error rate (insertion errors + deletion errors + substitution errors):  11.925380710659898
Percent phoneme error:  0.4858844698144816


In [80]:
"""
Sort the decoded sentences from best-to-worst. Display 100 best decoded sentences.
"""

indices = np.argsort(np.array(levs)/np.array(phoneLENGTHS))
print(indices[:100])

[1381  989 1800 1119 1525 1920 1056  835  170   87  741  536  939 1447
 1214 1126 1548 1538  988  408 1892 1751   55 1185   59  563 1478   42
 1027  237 1962 1768 1771  161  897  522 1562  785  155 1686 1908 1797
  778  790  239  363 1039 1821 1546  182 1467 1116  892  779  428  188
  894 1047   89  740  787  806 1128 1687  815  486  623 1121  738  991
  119 1130  706  693  764 1504 1437  685 1333  197 1798  540  655  236
  666  810 1212 1055 1268  192  703  997 1807 1231 1550  101 1233 1740
  289 1035]


In [81]:
"""
Visualize decoded sentences.
"""

which = 1381
print("Decoded phoneme sequence: ", decodedOut[which])
print("Ground truth phoneme sequence: ", phonemizedSentences[9000 + which])
print("Ground truth label: ", LABELS[9000 + which])
print(" ")
print("Levenshtein distance between decoded and ground truth sequence: ", Levenshtein.distance(decodedOut[which], phonemizedSentences[9000 + which]))
print("Length of ground truth sequence: ", len(phonemizedSentences[9000 + which]))

Decoded phoneme sequence:  ['IH', 'T', ' ', 'W', 'AA', 'Z', ' ', 'P', 'EY', 'T', ' ', 'F', 'AO', 'R']
Ground truth phoneme sequence:  ['IH', 'T', ' ', 'W', 'AA', 'Z', ' ', 'P', 'EY', 'D', ' ', 'F', 'AO', 'R']
Ground truth label:  it was paid for
 
Levenshtein distance between decoded and ground truth sequence:  1
Length of ground truth sequence:  14


In [None]:
"""
Do word modeling. It is a very simple 3-gram model.
The model was trained on the first 9000 sentences in the dataset (i.e., train and validation sets) with kneLM using the command: lmplz -o 3 < train.txt > 3gram.arpa

1) We mainly rely on matching phoneme segments to word pronunciations in a lexicon dectionary containing all possible unique words.
2) kneLM language model is only used to disambiguate and choose coherent word sequences.
"""

In [55]:
import kenlm
from rapidfuzz.distance import Levenshtein
from itertools import product
from heapq import heappush, heappop
from jiwer import wer
from typing import List, Dict, Tuple

In [56]:
lm = kenlm.Model("DATA/ckptsLargeVocab/3gram.arpa")

Loading the LM will be faster if you build a binary file.
Reading /mnt/dataDrive/emgFullCorpora/toUpload/DATA/ckptsLargeVocab/3gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


In [57]:
uniqueWords = []
for label in LABELS:
    words = label.split(" ")
    for word in words:
        if word not in uniqueWords:
            uniqueWords.append(word)

In [58]:
print(len(uniqueWords))
print(uniqueWords)

6516


In [59]:
phonemizedWords = {}

for word in sorted(uniqueWords):
    phones = []
    for p in g2p(word): 
        p = re.sub(r'[0-9]', '', p)   
        if re.match(r'[A-Z]+', p) or p == " ": 
            phones.append(p)
    phonemizedWords[word] = phones

In [65]:
def fuzzyMatch(
    phonemeSegment: List[str],
    lexicon: Dict[str, List[str]],
    maxDist: int = 8,
    topK: int = 20
) -> List[Tuple[str, float]]:
    matches = []
    for word, phonemes in lexicon.items():
        dist = Levenshtein.distance(phonemeSegment, phonemes)
        if dist <= maxDist:
            normDist = dist / max(len(phonemeSegment), len(phonemes))
            matches.append((word, normDist))
    return sorted(matches, key=lambda x: x[1])[:topK]

def phoneme2wordsHardSegmentation(
    phonemes: List[str],
    lexicon: Dict[str, List[str]],
    lm,
    blankToken: str = ' ',
    lambdaLM: float = 0.17,
    lambdaDist: float = 0.83,
    minLen: int = 2
) -> Tuple[List[str], float]:
   
    segments = []
    buffer = []
    for p in phonemes:
        if p == blankToken:
            if len(buffer) >= minLen or buffer == ["AH"] or buffer == ['AY']:
                segments.append(buffer)
            buffer = []
        else:
            buffer.append(p)
    if len(buffer) >= minLen or buffer == ["AH"] or buffer == ['AY']:
        segments.append(buffer)

    decodedWords = []
    normDistTotal = 0.0
    for segment in segments:
        matches = fuzzyMatch(segment, lexicon, maxDist = int(len(segment)) + 1)
        if not matches:
            decodedWords.append("<UNK>")
            continue

        
        scored = []
        for word, normDist in matches:
            
            candidate = decodedWords + [word]
            sentence = " ".join(candidate)
            lmScore = lm.score(sentence, bos = False, eos = False)  
            score = lambdaLM * lmScore - lambdaDist * normDist
            scored.append((word, score, normDist))

        bestWord, bestScore, bestNormDist = max(scored, key = lambda x: x[1])
        decodedWords.append(bestWord)
        normDistTotal += bestNormDist

    finalSentence = " ".join(decodedWords)
    finalLMScore = lm.score(finalSentence, bos = False, eos = False)

    totalScore = lambdaLM * finalLMScore - lambdaDist * normDistTotal
    return decodedWords, totalScore

In [66]:
RESULTS = []
for i, phonemeSeq in enumerate(decodedOut):
    words, _ = phoneme2wordsHardSegmentation(phonemeSeq, phonemizedWords, lm)
    RESULTS.append(" ".join(words))

In [67]:
WER = []
for i in range(len(RESULTS)):
    error = wer(LABELS[9000 + i], RESULTS[i])
    WER.append(error)

In [68]:
print("Mean word error rate: ", np.mean(WER))

Mean word error rate:  0.7814721925001112


In [82]:
"""
Visualize WERs.
"""
which = 1381
print("Decoded phoneme sequence: ", decodedOut[which])
print("Ground truth phoneme sequence: ", phonemizedSentences[9000 + which])
print(" ")
print("Levenshtein distance between decoded and ground truth sequence: ", Levenshtein.distance(decodedOut[which], phonemizedSentences[9000 + which]))
print("Length of ground truth sequence: ", len(phonemizedSentences[9000 + which]))
print(" ")
print("Decoded sentence: ", RESULTS[which])
print("Original sentence: ", LABELS[9000 + which])

Decoded phoneme sequence:  ['IH', 'T', ' ', 'W', 'AA', 'Z', ' ', 'P', 'EY', 'T', ' ', 'F', 'AO', 'R']
Ground truth phoneme sequence:  ['IH', 'T', ' ', 'W', 'AA', 'Z', ' ', 'P', 'EY', 'D', ' ', 'F', 'AO', 'R']
 
Levenshtein distance between decoded and ground truth sequence:  1
Length of ground truth sequence:  14
 
Decoded sentence:  it was pay for
Original sentence:  it was paid for
