In [1]:
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from g2p_en import G2p
import re

from basicOperations.manifoldOperations import matrixDistance, frechetMean
import torch.nn.utils as utils

from rnn import euclideanRnn
import math

import pickle
import Levenshtein
import os

import matplotlib.pyplot as plt

In [None]:
"""
Reproduces the results in Table 1, Figure 2, and Figure 3. The Icefall code used here is from https://github.com/k2-fsa/icefall
"""

In [None]:
"""
Train LARGE-VOCAB EMG-to-phoneme conversion.

For description of the data, please see largeVocabDataVisualization.ipynb

Unlike the SMALL-VOCAB data, there are no word-boundary timestamps within a sentence.

Each sentence is therefore decoded as a whole, with the model trained using CTC loss. The pipeline resembles standard speech-to-text (ASR) techniques.
"""

"""Install the Levenshtein package (used to compute edit distances): https://pypi.org/project/Levenshtein/"""

In [2]:
"""
Open Data.
"""

def _loadPickle(path):
    # pickle.load can execute arbitrary code: only open these files if they come from a trusted source.
    with open(path, "rb") as file:
        return pickle.load(file)

DATA = _loadPickle("DATA/dataLargeVocab.pkl")
LABELS = _loadPickle("DATA/labelsLargeVocab.pkl")

In [3]:
# Diag = True or False: feed the raw SPD matrices (False), or rotate them into the eigenbasis
# of the training-set Frechet mean so they are approximately diagonal (True).
DIAG = True

In [6]:
"""
English phoneme definitions, read from the Icefall lang_phone token table
(one "<symbol> <id>" pair per line).
"""
tok2id = {}
with open("DATA/ckptsLargeVocab/lang_phone/tokens.txt") as f:
    for line in f:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines instead of failing the two-field unpack below
        s, i = fields
        # Skip <eps> and '#'-prefixed symbols (presumably disambiguation symbols #0, #1, ...),
        # which are not phones.
        if s == "<eps>" or s.startswith("#"):
            continue
        tok2id[s] = int(i)
PHONE_DEF = tok2id


def phoneToId(p):
    """Return the token-table id of phone symbol ``p`` (raises KeyError if unknown)."""
    return PHONE_DEF[p]

g2p = G2p()

In [7]:
"""
Phonemize the sentences: run g2p, strip the digits (presumably ARPAbet stress markers),
and keep only tokens that start with an uppercase letter (phones, not punctuation/spaces).
"""

def _phonemizeLabel(label):
    stripped = (re.sub(r'[0-9]', '', token) for token in g2p(label))
    return [token for token in stripped if re.match(r'[A-Z]+', token)]

phonemizedSentences = [_phonemizeLabel(LABELS[i]) for i in range(len(LABELS))]

In [8]:
"""
Convert phones to token-table ids using the look-up dictionary PHONE_DEF (via phoneToId).
"""

phoneIndexedSentences = [
    [phoneToId(phone) for phone in sentence]
    for sentence in phonemizedSentences
]

In [9]:
def tokenIdToClassIdx(tokenId: int) -> int:
    """Map a token-table id to a model class index by shifting it down by one.

    NOTE(review): presumably token id 0 is <eps> in tokens.txt, so the shift makes
    phone classes start at 0 — confirm against the token table.
    """
    classIdx = tokenId - 1
    return classIdx

def phoneSeqToClassIdxSeq(phoneSeq):
    """Convert a sequence of phone symbols to model class indices (PHONE_DEF id minus one)."""
    classIdxSeq = []
    for phone in phoneSeq:
        classIdxSeq.append(tokenIdToClassIdx(PHONE_DEF[phone]))
    return classIdxSeq

classIndexedSentences = list(map(phoneSeqToClassIdxSeq, phonemizedSentences))

In [10]:
"""
Pad the phone-transcribed sentences to a common length (to be used with CTC loss).
Padding entries are -1; labelLengths records how many entries of each row are real labels.
"""

minLabelWidth = 76  # original fixed width; widened automatically if any sentence is longer
labelWidth = max([minLabelWidth] + [len(seq) for seq in classIndexedSentences])

phonemizedLabels = np.full((len(classIndexedSentences), labelWidth), -1.0)
for i, seq in enumerate(classIndexedSentences):
    phonemizedLabels[i, :len(seq)] = seq

labelLengths = np.array([len(seq) for seq in classIndexedSentences], dtype = float)

In [11]:
"""
z-normalize the data along the time dimension (last axis), independently for each channel.
"""

normDATA = []
for i in range(len(DATA)):
    Mean = np.mean(DATA[i], axis = -1, keepdims = True)
    Std = np.std(DATA[i], axis = -1, keepdims = True)
    # A constant channel has Std == 0; divide by 1 instead so it becomes all zeros rather
    # than NaN, which would otherwise poison every covariance matrix built from it.
    Std = np.where(Std == 0, 1.0, Std)
    normDATA.append((DATA[i] - Mean) / Std)

In [12]:
"""
Slice each recording into 50 ms windows with a 20 ms hop (signal sampled at 5000 Hz) and
turn every window into a shrinkage-regularized covariance (SPD) matrix.
"""

stepSize = 100     # hop: 100 samples = 20 ms at 5000 Hz
windowSize = 125   # half-width: each window spans 2 * windowSize = 250 samples = 50 ms

slicedMatrices = []
for recording in normDATA:
    numChannels, dataLength = recording.shape
    identity = np.eye(numChannels)
    # NOTE(review): numIters is computed from the half-width, so the last windows extend past
    # the end of the signal and are silently truncated by slicing, yet are still scaled by
    # 1 / (2 * windowSize). Kept as-is to reproduce the features the released checkpoint was
    # trained on; (dataLength - 2 * windowSize) // stepSize + 1 would keep only full windows.
    numIters = (dataLength - windowSize) // stepSize + 1

    collect = []
    for i in range(numIters):
        start = i * stepSize
        segment = recording[:, start:start + 2 * windowSize]
        temp = 1/(2 * windowSize) * (segment @ segment.T)
        # Shrink toward a trace-scaled identity so every matrix stays well conditioned.
        collect.append(0.9 * temp + 0.1 * np.trace(temp) * identity)
    slicedMatrices.append(collect)

In [13]:
"""
Approximately diagonalize the matrices using the Frechet mean. Only TRAIN-VAL data (the
first 9000 sentences) is used to compute the mean, so the test split does not leak into it.
"""

numTrainVal = 9000      # sentences [0, 9000) are train + val; must match the split cell
minPaddedLength = 409   # original fixed padding length; grown if a sentence has more windows

matricesForMean = np.array(
    [matrix for sentence in slicedMatrices[:numTrainVal] for matrix in sentence]
)
numChannels = matricesForMean.shape[-1]
manifoldMean = frechetMean()

MEAN = manifoldMean.mean(matricesForMean.reshape(-1, numChannels, numChannels))
# NOTE(review): MEAN is presumably symmetric, for which np.linalg.eigh is the appropriate
# routine (guaranteed real, orthonormal eigenvectors). Switching would reorder / re-sign the
# eigenvectors and change the features the released checkpoint expects, so eig is kept.
eigenvalues, eigenvectors = np.linalg.eig(MEAN)

inputLengths = np.array([len(sentence) for sentence in slicedMatrices], dtype = float)
paddedLength = max(minPaddedLength, int(inputLengths.max(initial = 0)))

# Unused time steps are padded with identity matrices. Storing float32 halves memory and does
# not change what the model sees: BaseDataset casts every sample to float32 anyway.
identityMatrix = np.eye(numChannels, dtype = np.float32)
afterMatrices = np.tile(identityMatrix, (len(slicedMatrices), paddedLength, 1, 1))
for i, sentence in enumerate(slicedMatrices):
    for j, matrix in enumerate(sentence):
        afterMatrices[i, j] = eigenvectors.T @ matrix @ eigenvectors if DIAG else matrix

In [None]:
"""np.save("DATA/ckptsLargeVocab/frechetMeanLargeVocab.npy", MEAN)"""

In [14]:
class BaseDataset(Dataset):
    """Map-style dataset yielding (features, padded labels, input length, label length).

    Args:
        data: indexable of per-sample feature arrays; each is cast to float32 on access.
        labels: indexable of per-sample (padded) label sequences, returned unchanged.
        inputLength: per-sample number of valid input frames (returned as int).
        targetLength: per-sample number of valid labels (returned as int).
    """

    def __init__(self, data, labels, inputLength, targetLength):
        self.data, self.labels = data, labels
        self.inputLength, self.targetLength = inputLength, targetLength

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        features = self.data[index].astype('float32')
        return (
            features,
            self.labels[index],
            int(self.inputLength[index]),
            int(self.targetLength[index]),
        )

In [15]:
"""
Train-validation-test split by sentence index: [0, 8000) train, [8000, 9000) validation,
[9000, end) test.
"""

trainPart, valPart, testPart = slice(None, 8000), slice(8000, 9000), slice(9000, None)

trainFeatures, trainLabels = afterMatrices[trainPart], phonemizedLabels[trainPart]
trainLabelLengths, trainInputLengths = labelLengths[trainPart], inputLengths[trainPart]

valFeatures, valLabels = afterMatrices[valPart], phonemizedLabels[valPart]
valLabelLengths, valInputLengths = labelLengths[valPart], inputLengths[valPart]

testFeatures, testLabels = afterMatrices[testPart], phonemizedLabels[testPart]
testLabelLengths, testInputLengths = labelLengths[testPart], inputLengths[testPart]

In [16]:
trainDataset, valDataset, testDataset = (
    BaseDataset(features, labels, inputLens, labelLens)
    for features, labels, inputLens, labelLens in (
        (trainFeatures, trainLabels, trainInputLengths, trainLabelLengths),
        (valFeatures, valLabels, valInputLengths, valLabelLengths),
        (testFeatures, testLabels, testInputLengths, testLabelLengths),
    )
)

batchSize = 32
# Only training batches are shuffled; decoding later indexes test outputs by dataset order.
trainDataloader = DataLoader(trainDataset, batch_size = batchSize, shuffle = True)
valDataloader = DataLoader(valDataset, batch_size = batchSize, shuffle = False)
testDataloader = DataLoader(testDataset, batch_size = batchSize, shuffle = False)

In [17]:
def trainOperation(model, device, trainLoader, rnnOptimizer, Loss):
    """Run one training epoch over ``trainLoader`` and return the mean per-batch loss.

    Args:
        model: network called as ``model(inputs, inputLengths_on_cpu)``.
        device: torch device that every batch tensor is moved to.
        trainLoader: iterable of (inputs, targets, inputLengths, targetLengths) batches;
            must also support ``len()``.
        rnnOptimizer: optimizer updating the model's parameters.
        Loss: callable ``Loss(outputs, targets, inputLengths, targetLengths)`` returning a
            scalar tensor.
    """
    model.train()
    batchLosses = []
    for batch in trainLoader:
        inputs, targets, inputLengths, targetLengths = (item.to(device) for item in batch)

        rnnOptimizer.zero_grad()
        # Lengths are handed to the model on the CPU (presumably for pack_padded_sequence — TODO confirm).
        outputs = model(inputs, inputLengths.cpu())
        loss = Loss(outputs, targets, inputLengths, targetLengths)
        loss.backward()
        rnnOptimizer.step()

        batchLosses.append(loss.item())

    return sum(batchLosses) / len(trainLoader)


def valOperation(model, device, valLoader, Loss):
    """Evaluate ``model`` on ``valLoader`` without gradients; return the mean per-batch loss.

    Args:
        model: network called as ``model(inputs, inputLengths_on_cpu)``.
        device: torch device that every batch tensor is moved to.
        valLoader: iterable of (inputs, targets, inputLengths, targetLengths) batches;
            must also support ``len()``.
        Loss: callable ``Loss(outputs, targets, inputLengths, targetLengths)`` returning a
            scalar tensor.
    """
    model.eval()
    batchLosses = []
    with torch.no_grad():
        for batch in valLoader:
            inputs, targets, inputLengths, targetLengths = (item.to(device) for item in batch)
            outputs = model(inputs, inputLengths.cpu())
            batchLosses.append(Loss(outputs, targets, inputLengths, targetLengths).item())

    return sum(batchLosses) / len(valLoader)

In [18]:
"""
To replicate the PER (phoneme error rate) for other model sizes and depths, change the hidden
dimension (second positional argument) and numLayers in the RnnNet call below, e.g.
    euclideanRnn.RnnNet(41, 25, device, numLayers = 3)
The 41 outputs are presumably the 40 phone classes plus the CTC blank (index 40, see CTCLoss).
"""

# Fall back to the CPU so the notebook still runs (slowly) on machines without CUDA.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

numberEpochs = 100

model = euclideanRnn.RnnNet(41, 25, device, numLayers = 3).to(device)
numParams = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(numParams)
lossFunction = nn.CTCLoss(blank = 40, zero_infinity = True)
rnnOptimizer = optim.Adam(model.parameters(), lr = 0.001, weight_decay = 1e-3)

6348591


In [None]:
"""
Do training. A checkpoint is written every epoch (named by 0-based epoch), so the best epoch
can be picked afterwards from valLOSS.
"""

checkpointDir = "ckpts/largeVocab"
# torch.save does not create missing parent directories.
os.makedirs(checkpointDir, exist_ok = True)

valLOSS = []
minLOSS = float("inf")
for epoch in range(numberEpochs):
    trainLoss = trainOperation(model, device, trainDataloader, rnnOptimizer, lossFunction)
    valLoss = valOperation(model, device, valDataloader, lossFunction)
    valLOSS.append(valLoss)
    minLOSS = min(minLOSS, valLoss)
    torch.save(model.state_dict(), checkpointDir + "/" + str(epoch) + ".pt")
    print(f'Epoch: {epoch + 1}/{numberEpochs}, Training loss: {trainLoss:.4f}, Val loss: {valLoss:.4f}')

In [None]:
"""np.save("ckpts/largeVocab/valLoss.npy", valLOSS)"""

In [20]:
# Select the epoch with the lowest validation loss (checkpoints are named by 0-based epoch).
valLoss = np.load("ckpts/largeVocab/valLoss.npy")
epoch = np.argmin(valLoss)
print(np.min(valLoss))
print(epoch)

1.6524379551410675
90


In [19]:
def testOperation(model, device, testLoader, Loss):
    model.eval()
    totalLoss = 0
    Outputs = []
    with torch.no_grad():
        for inputs, targets, inputLengths, targetLengths in testLoader:
            inputs, targets = inputs.to(device), targets.to(device)
            inputLengths, targetLengths = inputLengths.to(device), targetLengths.to(device)
            
            outputs = model(inputs, inputLengths.cpu()) 

            loss = Loss(outputs, targets, inputLengths, targetLengths)
            totalLoss += loss.item()
            Outputs.append(outputs.transpose(0, 1))

    return Outputs, totalLoss / len(testLoader)

In [26]:
"""
Simple beam-search algorithm.
"""

def ctcPrefixBeamSearch(
    logProbs,
    testLen = None,
    beamSize = 50,
    blank = 40,
    topk = None,
    allowDoubles = True,
):
    """CTC prefix beam search: return the highest-scoring collapsed label sequence.

    Each beam entry maps a label prefix (tuple of class indices, blanks removed and
    repeats collapsed) to a pair (pb, pnb): the log-score of the alignments of that
    prefix whose last frame is a blank (pb) or a non-blank label (pnb).

    Args:
        logProbs: array-like of shape (Ttotal, V). Scores are combined by addition and
            np.logaddexp, so these must be log-probabilities (presumably the model's
            log-softmax outputs — TODO confirm).
        testLen: number of valid frames; later frames (batch padding) are ignored.
            None uses all Ttotal frames.
        beamSize: number of prefixes kept after each frame.
        blank: index of the CTC blank symbol.
        topk: if set (and < V), only the topk highest-scoring labels of each frame are
            tried as extensions. The blank occupies one of the k slots (it replaces the
            weakest candidate when not already among them), although blank transitions
            are always scored separately. Note topk=0 keeps the whole argpartition
            result, i.e. behaves like no pruning.
        allowDoubles: if False, a label equal to the prefix's last label can only merge
            into it; the blank-separated repeat path is dropped, so doubled labels
            (e.g. "AA", blank, "AA") are never produced.

    Returns:
        Tuple of label indices of the best prefix (may be empty).
    """
    lp = np.asarray(logProbs)
    Ttotal, V = lp.shape
    # Only decode the first testLen frames; the rest is padding.
    T = Ttotal if testLen is None else int(min(testLen, Ttotal))

    # Start from the empty prefix: log P(ends in blank) = 0, log P(ends in non-blank) = -inf.
    beams = {(): (0.0, -np.inf)}

    def add(store, seq, addPb, addPnb):
        # Log-sum-exp a (blank, non-blank) contribution into store[seq]; -inf contributions
        # are skipped since they would not change the sum.
        if seq in store:
            pb, pnb = store[seq]
            if addPb  != -np.inf: pb  = np.logaddexp(pb,  addPb)
            if addPnb != -np.inf: pnb = np.logaddexp(pnb, addPnb)
            store[seq] = (pb, pnb)
        else:
            store[seq] = (addPb, addPnb)

    for t in range(T):
        row = lp[t] 
        new = {}

        # Optional per-frame pruning of candidate labels (blank is forced into the set).
        if topk is not None and topk < V:
            cand = np.argpartition(row, -topk)[-topk:]
            if blank not in cand:
                worstIdx = cand[np.argmin(row[cand])]
                cand[cand == worstIdx] = blank
        else:
            cand = range(V)

        for seq, (pb, pnb) in beams.items():
            # Emitting blank keeps the prefix and moves all its mass to "ends in blank".
            add(new, seq, np.logaddexp(pb, pnb) + row[blank], -np.inf)

            last = seq[-1] if seq else None

            for c in cand:
                if c == blank:
                    continue
                pC = row[c]

                if c == last:
                    # Repeating the last label without a blank in between collapses into the same prefix.
                    add(new, seq, -np.inf, pnb + pC)

                    # After a blank, the same label starts a new (doubled) symbol.
                    if allowDoubles:
                        add(new, seq + (c,), -np.inf, pb + pC)
                else:
                    # A different label extends the prefix from either ending.
                    add(new, seq + (c,), -np.inf, np.logaddexp(pb, pnb) + pC)

        # Keep only the beamSize best prefixes by total score logaddexp(pb, pnb).
        if len(new) > beamSize:
            items = sorted(new.items(),
                           key = lambda kv: np.logaddexp(*kv[1]),
                           reverse = True)[:beamSize]
            beams = dict(items)
        else:
            beams = new

    bestSeq = max(beams.items(), key = lambda kv: np.logaddexp(*kv[1]))[0]
    return bestSeq

def findClosestTranscription(decodedTranscript, phoneticTranscription):
    """Return the Levenshtein (edit) distance between a decoded and a reference phone sequence.

    Despite the name, no search over candidates happens: this is a thin wrapper around
    ``Levenshtein.distance``.
    """
    return Levenshtein.distance(decodedTranscript, phoneticTranscription)

In [None]:
# map_location loads the checkpoint onto whatever `device` is in use, even if it was saved on a GPU.
modelWeight = torch.load("DATA/ckptsLargeVocab/ckptWithoutSpaces.pt", map_location = device, weights_only = True)
"""modelWeight = torch.load("ckpts/largeVocab/" + str(90)  + '.pt', weights_only = True)"""
model.load_state_dict(modelWeight)
output, testLoss = testOperation(model, device, testDataloader, lossFunction)

print("TEST LOSS: ", testLoss)

TEST LOSS:  1.9241562139603399


In [22]:
# One entry per test utterance, in dataset order (flattening the per-batch outputs).
outs = [utterance for batchOutput in output for utterance in batchOutput]

In [23]:
# Sanity check: number of test utterances and the shape of the first output
# (presumably (frames, classes) after the transpose in testOperation).
print(len(outs))
print(outs[0].shape)

1970
torch.Size([192, 41])


In [24]:
# Inverse lookup from model class index (token id - 1) back to the phone symbol.
PHONE_DEF1 = {tokenId - 1: phone for phone, tokenId in PHONE_DEF.items()}

In [27]:
# Decode every test utterance with CTC prefix beam search, map class indices back to phones,
# and score each against its ground-truth phone sequence.
testOffset = 9000  # index of the first test sentence; the test split is afterMatrices[9000:]

LEVS = []
decodedOut = []
for k, logProbs in enumerate(outs):
    # testInputLengths[k] trims batch padding; outs follows the test split's order because
    # the test dataloader does not shuffle.
    decodedSymbols = ctcPrefixBeamSearch(logProbs.cpu().numpy(), testInputLengths[k])
    decodedOut.append([PHONE_DEF1[symbol] for symbol in decodedSymbols])

levs = []
phoneLENGTHS = []
for k, decoded in enumerate(decodedOut):
    reference = phonemizedSentences[testOffset + k]
    phoneLENGTHS.append(len(reference))
    levs.append(findClosestTranscription(decoded, reference))
LEVS.append(np.mean(levs))

In [28]:
# Note: the "Percent phoneme error" value is a fraction (e.g. 0.507 means 50.7 %).
meanSentenceLength = np.mean(phoneLENGTHS)
meanPhonemeErrors = np.mean(levs)
phonemeErrorRate = np.sum(levs) / np.sum(phoneLENGTHS)
print("Mean length of sentences: ", meanSentenceLength)
print("Mean phoneme errors (insertion errors + deletion errors + substitution errors): ", meanPhonemeErrors)
print("Percent phoneme error: ", phonemeErrorRate)

Mean length of sentences:  19.590862944162435
Mean phoneme errors (insertion errors + deletion errors + substitution errors):  9.93502538071066
Percent phoneme error:  0.5071254599160492


In [29]:
"""
Sort the decoded sentences from best to worst by per-sentence phoneme error rate and
display the indices of the 100 best.
"""

sentenceErrorRates = np.asarray(levs) / np.asarray(phoneLENGTHS)
indices = np.argsort(sentenceErrorRates)
print(indices[:100])

[ 991 1341 1562 1543 1381 1687 1214  452 1920  693  841  685  214 1619
  190 1180  540  669 1055 1436 1798 1030  602 1756   33 1872 1771  361
  682  192 1222  633  850  852  976  881  564  132  762  313  155  620
 1473  973   34 1855  202 1478   59  839  989  215 1150  225 1119 1892
   98  522  939  563  288 1184 1185 1035 1315 1751 1962  818   94 1633
 1901  771 1542 1467 1924 1135 1326 1123  454 1059  417  536  784  637
  671 1927  974  775 1879 1002 1179  924  996 1800  354  449 1056  302
 1769 1081]


In [30]:
"""
Visualize one decoded sentence next to its ground truth.
"""

which = 991  # best-scoring test sentence in the run above (indices[0])
decoded = decodedOut[which]
groundTruth = phonemizedSentences[9000 + which]
print("Decoded phoneme sequence: ", decoded)
print("Ground truth phoneme sequence: ", groundTruth)
print("Ground truth label: ", LABELS[9000 + which])
print(" ")
print("Levenshtein distance between decoded and ground truth sequence: ", Levenshtein.distance(decoded, groundTruth))
print("Length of ground truth sequence: ", len(groundTruth))

Decoded phoneme sequence:  ['DH', 'EY', 'V', 'G', 'AA', 'T', 'AH', 'N', 'AY', 'S', 'W', 'AH', 'N']
Ground truth phoneme sequence:  ['DH', 'EY', 'V', 'G', 'AA', 'T', 'AH', 'N', 'AY', 'S', 'W', 'AH', 'N']
Ground truth label:  theyve got a nice one
 
Levenshtein distance between decoded and ground truth sequence:  0
Length of ground truth sequence:  13
