In [1]:
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from g2p_en import G2p
import re

from basicOperations.manifoldOperations import matrixDistance, frechetMean
import torch.nn.utils as utils

from rnn import euclideanRnn
import math

import pickle
import Levenshtein
import os

import matplotlib.pyplot as plt

In [None]:
"""
Code to reproduce Table 1, Figure 2, and Figure 3.
"""

In [None]:
"""
Train LARGE-VOCAB EMG-to-phoneme conversion.

For description of the data, please see largeVocabDataVisualization.ipynb

Unlike the SMALL-VOCAB data, there are no word-boundary timestamps within a sentence.

Each sentence is therefore decoded as a whole, and the model is trained with CTC loss. The pipeline resembles standard automatic speech recognition (ASR) techniques.
"""

In [3]:
"""
Open Data.
"""

with open("DATA/dataLargeVocab.pkl", "rb") as file:
    DATA = pickle.load(file)

with open("DATA/labelsLargeVocab.pkl", "rb") as file:
    LABELS = pickle.load(file)

In [4]:
"""
English phoneme definitions.
"""

PHONE_DEF = ['AO', 'OY', 'DH', 'ZH', 'SH', 'CH', 'UH', 'NG', 'IY', 'AA', 'W', 'S', 'IH', 'K', 'EY', 'JH', 'Y', 'N', 'OW', 'M', 'P', 'T', 'B', 'AY', 'UW', 'R', 'G', 'EH', 'Z', 'TH', 'AW', 
             'HH', 'AH', 'AE', 'L', 'ER', 'F', 'V', 'D', ' ', 'SIL']

def phoneToId(p):
    return PHONE_DEF.index(p)

# Grapheme-to-phoneme converter from g2p_en; called on each sentence label in the next cell.
g2p = G2p()

In [5]:
"""
Phonemize the sentences.
"""

phonemizedSentences = []

for i in range(len(LABELS)):
    phones = []
    for p in g2p(LABELS[i]): 
        p = re.sub(r'[0-9]', '', p)   
        if re.match(r'[A-Z]+', p) or p == " ": 
            phones.append(p)
    phonemizedSentences.append(phones)

In [6]:
"""
Convert phone-to-indices using look-up dictionary PHONE_DEF.
"""

phoneIndexedSentences = []
for i in range(len(phonemizedSentences)):
    current = phonemizedSentences[i]
    phoneID = []
    for j in range(len(current)):
        phoneID.append(phoneToId(current[j]))
    phoneIndexedSentences.append(phoneID)

In [7]:
"""
Pad the phone transcribed sentences to a common length (to be used with CTC loss).
"""

phonemizedLabels = np.zeros((len(phoneIndexedSentences), 76)) - 1
for i in range(len(phoneIndexedSentences)):
    phonemizedLabels[i, 0:len(phoneIndexedSentences[i])] = phoneIndexedSentences[i]

labelLengths = np.zeros((len(phoneIndexedSentences)))
for i in range(len(phoneIndexedSentences)):
    labelLengths[i] = len(phoneIndexedSentences[i])

In [8]:
"""
z-normalize the data along the time dimension.
"""

normDATA = []
for i in range(len(DATA)):
    Mean = np.mean(DATA[i], axis = -1)
    Std = np.std(DATA[i], axis = -1)
    normDATA.append((DATA[i] - Mean[..., np.newaxis])/Std[..., np.newaxis])

In [9]:
"""
Slice the matrices into 50ms segments with a step size of 20ms. Signal is sampled at 5000 Hertz.
"""

slicedMatrices = []
for j in range(len(normDATA)):
    collect = []
    stepSize = 100 
    windowSize = 125
    dataLength = normDATA[j].shape[1]
    numIters = (dataLength - windowSize) // stepSize + 1
       
    for i in range(numIters):
        where = i * stepSize + windowSize
        start = where - windowSize
        End = where + windowSize
        temp = 1/(2 * windowSize) * (normDATA[j][:, start:End] @ normDATA[j][:, start:End].T)
        collect.append(0.9 * temp + 0.1 * np.trace(temp) * np.eye(31))
    slicedMatrices.append(collect)

In [10]:
"""
Approximately diagonalize the matrices using Frechet mean. Use only TRAIN-VAL data for calculating Frechet mean.
"""

matricesForMean = []
for i in range(9000):
    for j in range(len(slicedMatrices[i])):
        matricesForMean.append(slicedMatrices[i][j])

matricesForMean = np.array(matricesForMean)
manifoldMean = frechetMean()

MEAN = manifoldMean.mean(matricesForMean.reshape(-1, 31, 31))
eigenvalues, eigenvectors = np.linalg.eig(MEAN)

identityMatrix = np.eye(31)
afterMatrices = np.tile(identityMatrix, (len(slicedMatrices), 409, 1, 1)) 
inputLengths = np.zeros((len(slicedMatrices)))
for i in range(len(slicedMatrices)):
    for j in range(len(slicedMatrices[i])):
        temp = eigenvectors.T @ slicedMatrices[i][j] @ eigenvectors
        afterMatrices[i, j] = temp
    inputLengths[i] = len(slicedMatrices[i])

In [None]:
"""np.save("DATA/ckptsLargeVocab/frechetMeanLargeVocab.npy", MEAN)"""

In [11]:
class BaseDataset(Dataset):
    """Map-style dataset returning (features, targets, inputLength, targetLength) per sample.

    Features are cast to float32 on access; both lengths are returned as Python ints.
    Targets are returned as stored (in this notebook, rows padded with -1).
    """

    def __init__(self, data, labels, inputLength, targetLength):
        # Indexable per-sample containers; the first axis is the sample index.
        self.inputLength = inputLength
        self.targetLength = targetLength
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return (
            self.data[index].astype('float32'),
            self.labels[index],
            int(self.inputLength[index]),
            int(self.targetLength[index]),
        )

In [12]:
"""
Train-validation-test split.
"""

trainFeatures = afterMatrices[:8000]
trainLabels = phonemizedLabels[:8000]
trainLabelLengths = labelLengths[:8000]
trainInputLengths = inputLengths[:8000]

valFeatures = afterMatrices[8000:9000]
valLabels = phonemizedLabels[8000:9000]
valLabelLengths = labelLengths[8000:9000]
valInputLengths = inputLengths[8000:9000]

testFeatures = afterMatrices[9000:]
testLabels = phonemizedLabels[9000:]
testLabelLengths = labelLengths[9000:]
testInputLengths = inputLengths[9000:]

In [13]:
# BaseDataset takes (features, labels, inputLengths, targetLengths) in that order.
trainDataset = BaseDataset(trainFeatures, trainLabels, trainInputLengths, trainLabelLengths)
valDataset = BaseDataset(valFeatures, valLabels, valInputLengths, valLabelLengths)
testDataset = BaseDataset(testFeatures, testLabels, testInputLengths, testLabelLengths)

# Only the training loader shuffles; val/test keep sentence order so outputs can be matched
# back to phonemizedSentences by index.
batchSize = 32
trainDataloader, valDataloader, testDataloader = (
    DataLoader(dataset, batch_size = batchSize, shuffle = shuffle)
    for dataset, shuffle in ((trainDataset, True), (valDataset, False), (testDataset, False))
)

In [14]:
def trainOperation(model,  device, trainLoader, rnnOptimizer, Loss):
    """Run one training epoch and return the mean of the per-batch losses.

    Args:
        model: called as model(inputs, inputLengths_on_cpu); its output is passed straight to Loss.
        device: device that inputs, targets and both length tensors are moved to.
        trainLoader: iterable of (inputs, targets, inputLengths, targetLengths) batches with a len().
        rnnOptimizer: optimizer stepped once per batch.
        Loss: called as Loss(outputs, targets, inputLengths, targetLengths).

    Returns:
        Sum of batch losses divided by len(trainLoader). Raises ZeroDivisionError for an empty loader.
    """
    model.train()
    batchLosses = []
    for batch in trainLoader:
        inputs, targets, inputLengths, targetLengths = (tensor.to(device) for tensor in batch)

        rnnOptimizer.zero_grad()
        # The model receives the lengths on CPU (kept from the original call convention).
        predictions = model(inputs, inputLengths.cpu())
        batchLoss = Loss(predictions, targets, inputLengths, targetLengths)
        batchLoss.backward()
        rnnOptimizer.step()

        batchLosses.append(batchLoss.item())

    return sum(batchLosses) / len(trainLoader)


def valOperation(model, device, valLoader, Loss):
    """Evaluate `model` without gradients and return the mean of the per-batch losses.

    Batches are (inputs, targets, inputLengths, targetLengths); all four are moved to `device`,
    and the model receives the input lengths on CPU. Raises ZeroDivisionError for an empty loader.
    """
    model.eval()
    runningLoss = 0
    with torch.no_grad():
        for batch in valLoader:
            inputs, targets, inputLengths, targetLengths = (tensor.to(device) for tensor in batch)
            predictions = model(inputs, inputLengths.cpu())
            runningLoss += Loss(predictions, targets, inputLengths, targetLengths).item()

    return runningLoss / len(valLoader)

In [15]:
"""
To replicate the PER (phoneme error rate) for various model sizes and layers, change the variable here:
euclideanRnn.RnnNet(41, modelHiddenDimension = 25, device, numLayers = 3).to(device)
"""

dev = "cuda:0"
device = torch.device(dev)

numberEpochs = 100

model = euclideanRnn.RnnNet(41, 25, device, numLayers = 3).to(device)
numParams = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(numParams)
lossFunction = nn.CTCLoss(blank = 40, zero_infinity = True)
rnnOptimizer = optim.Adam(model.parameters(), lr = 0.001, weight_decay = 1e-3)

6348591


In [None]:
"""
Do training.
"""

valLOSS = []
minLOSS = 100
for epoch in range(numberEpochs):
    trainLoss = trainOperation(model, device, trainDataloader, rnnOptimizer, lossFunction)
    valLoss = valOperation(model, device, valDataloader, lossFunction)
    valLOSS.append(valLoss)
    if minLOSS > valLoss:
        minLOSS = valLoss
    torch.save(model.state_dict(), "ckpts/largeVocab/" + str(epoch) + ".pt")
    print(f'Epoch: {epoch + 1}/{numberEpochs}, Training loss: {trainLoss:.4f}, Val loss: {valLoss:.4f}')

In [17]:
# Persist the per-epoch validation losses so the best epoch can be chosen without retraining.
np.save("ckpts/largeVocab/valLoss.npy", np.asarray(valLOSS))

In [18]:
# Pick the epoch (0-based, matching the checkpoint file names) with the lowest validation loss.
valLoss = np.load("ckpts/largeVocab/valLoss.npy")
epoch = np.argmin(valLoss)
print(np.min(valLoss))
print(epoch)

1.5844007954001427
88


In [19]:
def testOperation(model, device, testLoader, Loss):
    model.eval()
    totalLoss = 0
    Outputs = []
    with torch.no_grad():
        for inputs, targets, inputLengths, targetLengths in testLoader:
            inputs, targets = inputs.to(device), targets.to(device)
            inputLengths, targetLengths = inputLengths.to(device), targetLengths.to(device)
            
            outputs = model(inputs, inputLengths.cpu()) 

            loss = Loss(outputs, targets, inputLengths, targetLengths)
            totalLoss += loss.item()
            Outputs.append(outputs.transpose(0, 1))

    return Outputs, totalLoss / len(testLoader)

In [20]:
"""
Simple beam-search algorithm.
"""

def beamSearch(ctcOutput, testInputLength, beamWidth = 5, blank = 40):
    T, V = ctcOutput.shape
    T = int(min(T, testInputLength))
    beams = {(): (0.0, -np.inf)}  

    for t in range(T):
        newBeams = {}

        for seq, (logProbBlank, logProbNonBlank) in beams.items():
            for symbol in range(V):
                prob = ctcOutput[t, symbol]  

                if symbol == blank:
                    newLogProbBlank = np.logaddexp(logProbBlank, logProbNonBlank) + prob
                    prevLogProbBlank, prevLogProbNonBlank = newBeams.get(seq, (-np.inf, -np.inf))
                    newBeams[seq] = (np.logaddexp(prevLogProbBlank, newLogProbBlank), prevLogProbNonBlank)
                else:
                    newSeq = seq + (symbol,) if not seq or seq[-1] != symbol else seq

                    newLogProbNonBlank = (
                        logProbBlank + prob if seq and seq[-1] == symbol
                        else np.logaddexp(logProbBlank, logProbNonBlank) + prob
                    )

                    if newSeq in newBeams:
                        prevLogProbBlank, prevLogProbNonBlank = newBeams[newSeq]
                        newBeams[newSeq] = (
                            prevLogProbBlank, 
                            np.logaddexp(prevLogProbNonBlank, newLogProbNonBlank)
                        )
                    else:
                        newBeams[newSeq] = (-np.inf, newLogProbNonBlank)

        
        beams = dict(sorted(newBeams.items(), key = lambda x: np.logaddexp(*x[1]), reverse = True)[:beamWidth])

    bestSeq = max(beams.items(), key = lambda x: np.logaddexp(*x[1]))[0]
    return bestSeq

def findClosestTranscription(decodedTranscript, phoneticTranscription):
    """Return the Levenshtein edit distance between two phoneme sequences.

    The distance counts insertions, deletions and substitutions, comparing element by element.
    """
    return Levenshtein.distance(decodedTranscript, phoneticTranscription)

In [21]:
# Reload the checkpoint of the epoch with the lowest validation loss (`epoch`) and evaluate on the
# test split. map_location lets a checkpoint saved from GPU tensors load on the active device
# (including CPU-only machines).
modelWeight = torch.load("ckpts/largeVocab/" + str(epoch)  + '.pt', map_location = device, weights_only = True)
model.load_state_dict(modelWeight)
output, testLoss = testOperation(model, device, testDataloader, lossFunction)

print("TEST LOSS: ", testLoss)

TEST LOSS:  1.8881642395450222


In [22]:
# Flatten the per-batch outputs into one entry per test sentence (the test loader is unshuffled).
outs = [sample for batchOutput in output for sample in batchOutput]

In [23]:
# Beam-search decode every test sentence and score it against its ground-truth phonemes.
testOffset = 9000  # index of the first test sentence in phonemizedSentences / LABELS

LEVS = []
decodedOut = []
# Iterate over however many outputs exist instead of a hard-coded 1970, and avoid reusing `i`
# for the inner loop.
for i in range(len(outs)):
    decodedSymbols = beamSearch(outs[i].cpu().numpy(), testInputLengths[i])
    decodedOut.append([PHONE_DEF[symbol] for symbol in decodedSymbols])

references = phonemizedSentences[testOffset:testOffset + len(decodedOut)]
phoneLENGTHS = [len(reference) for reference in references]
levs = [findClosestTranscription(decoded, reference) for decoded, reference in zip(decodedOut, references)]
LEVS.append(np.mean(levs))

In [24]:
print("Mean length of sentences: ", np.mean(phoneLENGTHS))
print("Mean phoneme errors (insertion errors + deletion errors + substitution errors): ", np.mean(levs))
print("Percent phoneme error: ", np.mean(levs)/np.mean(phoneLENGTHS))

Mean length of sentences:  24.543654822335025
Mean phoneme errors (insertion errors + deletion errors + substitution errors):  12.212690355329949
Percent phoneme error:  0.49759053587309465


In [25]:
"""
Sort the decoded sentences from best-to-worst. Display 100 best decoded sentences.
"""

indices = np.argsort(np.array(levs)/np.array(phoneLENGTHS))
print(indices[:100])

[ 852 1381   85 1525 1233 1920   89 1230   87 1346  839  215  452  564
  973 1116  810  188 1067 1807  988 1130 1139  698 1447  806  214 1055
 1619 1185  540  841 1121   51 1927  897  549  785 1212  634  715  850
 1633 1873  813 1199 1546  129 1332 1218  602  505 1054  706  501 1224
  693 1437  155  910 1671  639   94 1783  202  797 1056 1243  256 1898
 1872 1771  416  811   35  741  306  399  227  705  790 1711 1123  836
  656  539 1027  374 1741 1550  134 1480 1128   95 1740  964 1382 1491
 1003  743]


In [26]:
"""
Visualize decoded sentences.
"""

which = 1381
print("Decoded phoneme sequence: ", decodedOut[which])
print("Ground truth phoneme sequence: ", phonemizedSentences[9000 + which])
print("Ground truth label: ", LABELS[9000 + which])
print(" ")
print("Levenshtein distance between decoded and ground truth sequence: ", Levenshtein.distance(decodedOut[which], phonemizedSentences[9000 + which]))
print("Length of ground truth sequence: ", len(phonemizedSentences[9000 + which]))

Decoded phoneme sequence:  ['IH', 'T', ' ', 'W', 'AA', 'Z', ' ', 'P', 'EY', 'T', ' ', 'F', 'AO', 'R']
Ground truth phoneme sequence:  ['IH', 'T', ' ', 'W', 'AA', 'Z', ' ', 'P', 'EY', 'D', ' ', 'F', 'AO', 'R']
Ground truth label:  it was paid for
 
Levenshtein distance between decoded and ground truth sequence:  1
Length of ground truth sequence:  14
