In [1]:
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from g2p_en import G2p
import re

from basicOperations.manifoldOperations import matrixDistance, frechetMean
import torch.nn.utils as utils

from rnn import manifoldRnn
import math

import pickle
from Levenshtein import distance
import os

import Levenshtein

In [None]:
"""
Proof for table 2, figure 4.
"""

"""
Train SMALL-VOCAB EMG-to-phoneme conversion.

This is a small corpus of data consisting of 500 sentences.

Each sentence is of the following format. WEEKDAY - MONTH - DATE - YEAR.
Unlike the LARGE-VOCAB data, we have time-stamps between weekday, month, date, and year (this is a much simpler dataset).

A sentence was displayed on a GUI and a cue was given for when to articulate each of WEEKDAY - MONTH - DATE - YEAR.

WEEKDAY - articulated in a window of duration 2s.
MONTH - articulated in a window of duration 2s.
DATE - articulated in a window of duration 2s.
YEAR - articulated in a window of duration 3s.

Split each sentence at these boundaries and train the model.

This small dataset can be used for fine-grained analysis, such as developing algorithms to demarcate speech and non-speech using EMG.

"""

In [2]:
# Load the raw recordings and the label array from disk.
# DATA is indexed later as DATA[sentence, :, sample]; only the first 499 label rows are kept.
DATA = np.load("DATA/dataSmallVocab.npy")
rawLabels = np.load("DATA/labelsSmallVocab.npy")
LABELS = rawLabels[:499]

In [3]:
""" z-normalize the data. """

# Per-trace statistics over the last axis, then standardise every trace to
# zero mean and unit standard deviation along that axis.
Mean = DATA.mean(axis = -1)
Std = DATA.std(axis = -1)
DATA = (DATA - np.expand_dims(Mean, axis = -1)) / np.expand_dims(Std, axis = -1)

In [4]:
"""Convert words to their phoneme sequences."""

# Clean every label entry: hyphens become spaces, surrounding whitespace is
# stripped and the text is lower-cased. WORDS keeps the per-sentence grouping.
WORDS = [
    [word.replace("-", " ").strip().lower() for word in LABELS[row]]
    for row in range(len(LABELS))
]

# Flatten to one entry per segment, in sentence order.
allWORDS = [word for sentence in WORDS for word in sentence]

# Phoneme inventory; a phoneme's list position is its integer class id.
# 'SIL' is last (index 33).
PHONE_DEF = [
    'AO', 'CH', 'IY', 'AA', 'W', 'S', 'IH', 'K', 'EY', 'JH',
    'Y', 'N', 'OW', 'M', 'P', 'T', 'B', 'AY', 'UW', 'R',
    'G', 'EH', 'Z', 'TH', 'AW', 'HH', 'AH', 'AE', 'L', 'ER',
    'F', 'V', 'D', 'SIL',
]

def phoneToId(p):
    """Return the index of phoneme ``p`` in PHONE_DEF.

    Raises ValueError when ``p`` is not part of the inventory.
    """
    phoneId = PHONE_DEF.index(p)
    return phoneId

g2p = G2p()

# Grapheme-to-phoneme conversion for every segment: drop characters outside
# letters / hyphen / space / apostrophe, strip digits from each g2p token, and
# keep only tokens that start with an upper-case letter.
phonemizedWords = []
for rawWord in allWORDS:
    cleaned = re.sub(r'[^a-zA-Z\- \']', '', rawWord)
    stripped = [re.sub(r'[0-9]', '', token) for token in g2p(cleaned)]
    phonemizedWords.append([token for token in stripped if re.match(r'[A-Z]+', token)])

# Map each phoneme symbol to its integer id.
phone2index = [[phoneToId(phone) for phone in phones] for phones in phonemizedWords]

# Zero-padded label matrix: one row per segment, at most 18 phoneme ids per row.
phonemizedLabels = np.zeros((499 * 4, 18))
for row in range(499 * 4):
    ids = phone2index[row]
    phonemizedLabels[row, 0:len(ids)] = ids

# Number of valid (non-padding) entries in each label row.
labelLengths = np.zeros((499 * 4))
for row, ids in enumerate(phone2index):
    labelLengths[row] = len(ids)

In [5]:
"""Chunk sentences at word boundaries."""

# Cut each sentence into four segments at fixed sample boundaries:
# three 10000-sample segments, then whatever remains after sample 30000.
segmentBounds = [(0, 10000), (10000, 20000), (20000, 30000), (30000, None)]
dataChunk = [
    DATA[sentence, :, lo:hi]
    for sentence in range(len(DATA))
    for lo, hi in segmentBounds
]


def _windowCovariances(chunk, numWindows):
    """Return ``numWindows`` regularised 31x31 matrices from one segment.

    Windows are 500 samples long and advance by 250 samples. Each window yields
    ``C = 1/250 * X @ X.T``, blended as ``0.9 * C + 0.1 * trace(C) * I``.
    """
    matrices = []
    for w in range(numWindows):
        lo = w * 250
        hi = lo + 500
        window = chunk[:, lo:hi]
        cov = 1/250 * (window @ window.T)
        matrices.append(0.9 * cov + 0.1 * np.trace(cov) * np.eye(31))
    return matrices


# Window count per segment length; segments of any other length are skipped.
windowsPerLength = {10000: 38, 15000: 58}

slicedMatrices = []
for segIdx in range(499 * 4):
    chunk = dataChunk[segIdx]
    numWindows = windowsPerLength.get(chunk.shape[1])
    if numWindows is None:
        continue
    slicedMatrices.append(_windowCovariances(chunk, numWindows))

In [6]:
""" Approximately diagonalize the covariance matrices using the Frechet mean. Use only the TRAIN-VAL set for computing the Frechet mean."""

# Pool every windowed matrix from the first 1600 segments; later segments are
# excluded from the mean.
matricesForMean = np.array([
    matrix
    for segIdx in range(1600)
    for matrix in slicedMatrices[segIdx]
])
manifoldMean = frechetMean()

MEAN = manifoldMean.mean(matricesForMean.reshape(-1, 31, 31))
# NOTE(review): MEAN is presumably symmetric, in which case np.linalg.eigh would
# guarantee real, orthonormal eigenvectors. Switching may reorder or re-sign the
# eigenvectors and so change the features any saved checkpoint was trained on;
# confirm before changing.
eigenvalues, eigenvectors = np.linalg.eig(MEAN)

# Rotate every matrix into the eigenbasis of the mean (V^T C V). Sequences
# shorter than 58 windows keep the identity matrix as padding.
identityMatrix = np.eye(31)
afterMatrices = np.tile(identityMatrix, (499 * 4, 58, 1, 1))
inputLengths = np.zeros((499 * 4))
for segIdx in range(499 * 4):
    windows = slicedMatrices[segIdx]
    for winIdx, covariance in enumerate(windows):
        afterMatrices[segIdx, winIdx] = eigenvectors.T @ covariance @ eigenvectors
    inputLengths[segIdx] = len(windows)

In [7]:
class BaseDataset(Dataset):
    """Index-aligned view over features, padded labels and their valid lengths.

    Args:
        data: indexable collection; each item must support ``.astype``.
        labels: indexable collection of target sequences.
        inputLength: per-item number of valid input steps.
        targetLength: per-item number of valid label entries.
    """

    def __init__(self, data, labels, inputLength, targetLength):
        self.data, self.labels = data, labels
        self.inputLength, self.targetLength = inputLength, targetLength

    def __len__(self):
        """Number of items, taken from ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(float32 input, label row, int input length, int target length)``."""
        features = self.data[index].astype('float32')
        return (
            features,
            self.labels[index],
            int(self.inputLength[index]),
            int(self.targetLength[index]),
        )

In [8]:
"""TRAIN-VAL-TEST split."""

# Contiguous split in segment order: 370 sentences for training, the next 30 for
# validation, everything after that for testing, with 4 segments per sentence.
# The former np.zeros(...) pre-allocations were dropped: each one was rebound to
# a slice straight away, so they only cost memory.
trainEnd = 370 * 4            # 1480 segments
valEnd = trainEnd + 30 * 4    # 1600 segments

trainFeatures = afterMatrices[:trainEnd]
trainLabels = phonemizedLabels[:trainEnd]
trainLabelLengths = labelLengths[:trainEnd]
trainInputLengths = inputLengths[:trainEnd]

valFeatures = afterMatrices[trainEnd:valEnd]
valLabels = phonemizedLabels[trainEnd:valEnd]
valLabelLengths = labelLengths[trainEnd:valEnd]
valInputLengths = inputLengths[trainEnd:valEnd]

testFeatures = afterMatrices[valEnd:]
testLabels = phonemizedLabels[valEnd:]
testLabelLengths = labelLengths[valEnd:]
testInputLengths = inputLengths[valEnd:]

In [9]:
# Wrap each split in a dataset. Note the argument order: input lengths come
# before label (target) lengths.
trainDataset = BaseDataset(
    data = trainFeatures,
    labels = trainLabels,
    inputLength = trainInputLengths,
    targetLength = trainLabelLengths,
)
valDataset = BaseDataset(
    data = valFeatures,
    labels = valLabels,
    inputLength = valInputLengths,
    targetLength = valLabelLengths,
)
testDataset = BaseDataset(
    data = testFeatures,
    labels = testLabels,
    inputLength = testInputLengths,
    targetLength = testLabelLengths,
)

# Only the training loader shuffles; validation and test keep the original order
# so outputs can be matched back to segments by position.
trainDataloader = DataLoader(trainDataset, shuffle = True, batch_size = 32)
valDataloader = DataLoader(valDataset, shuffle = False, batch_size = 32)
testDataloader = DataLoader(testDataset, shuffle = False, batch_size = 32)

In [10]:
def trainOperation(model, device, trainLoader, rnnOptimizer, Loss):
    """Run one optimisation pass over ``trainLoader``.

    Each batch is moved to ``device``, passed through ``model``, scored with
    ``Loss(outputs, targets, inputLengths, targetLengths)`` and followed by one
    optimiser step.

    Returns:
        The summed batch losses divided by ``len(trainLoader)``.
    """
    model.train()
    batchLosses = []
    for batch in trainLoader:
        inputs, targets, inputLengths, targetLengths = (item.to(device) for item in batch)

        rnnOptimizer.zero_grad()
        loss = Loss(model(inputs), targets, inputLengths, targetLengths)
        loss.backward()
        rnnOptimizer.step()

        batchLosses.append(loss.item())

    return sum(batchLosses) / len(trainLoader)


def valOperation(model, device, valLoader, Loss):
    """Evaluate ``model`` on ``valLoader`` without tracking gradients.

    Returns:
        The summed batch losses divided by ``len(valLoader)``.
    """
    model.eval()
    batchLosses = []
    with torch.no_grad():
        for batch in valLoader:
            inputs, targets, inputLengths, targetLengths = (item.to(device) for item in batch)
            batchLoss = Loss(model(inputs), targets, inputLengths, targetLengths)
            batchLosses.append(batchLoss.item())

    return sum(batchLosses) / len(valLoader)

In [23]:
""" Set ODE = True or False. For different model sizes, change the hidden dimension (`hidden`).

If you are loading ckptSmallVocabManifoldNoODE.pt, set the hidden dimension to 32.
If you are loading ckptSmallVocabManifoldODE.pt, set the hidden dimension to 15.
"""

# Use the first CUDA device when one is available; otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

numberEpochs = 100

# First argument (34) presumably is the number of output classes, matching
# len(PHONE_DEF) - TODO confirm against manifoldRnn.spdRnnNet.
model = manifoldRnn.spdRnnNet(34, hidden = 32, ODE = False, device = device).to(device)
numParams = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(numParams)
# Index 33 ('SIL' in PHONE_DEF) is the CTC blank; zero_infinity zeroes infinite losses.
lossFunction = nn.CTCLoss(blank = 33, zero_infinity = True)
rnnOptimizer = optim.Adam(model.parameters(), lr = 0.001, weight_decay = 1e-3)

5145634


In [None]:
"""Train the model."""
# Create the checkpoint directory up front: torch.save does not create missing
# directories, and failing here is cheaper than failing after the first epoch.
ckptDir = "ckpts/smallVocabManifold"
os.makedirs(ckptDir, exist_ok = True)

valLOSS = []
# Start from +inf so the running minimum is correct even if every loss exceeds 100.
minLOSS = float("inf")
for epoch in range(numberEpochs):
    trainLoss = trainOperation(model, device, trainDataloader, rnnOptimizer, lossFunction)
    valLoss = valOperation(model, device, valDataloader, lossFunction)
    valLOSS.append(valLoss)
    minLOSS = min(minLOSS, valLoss)
    # A checkpoint is written every epoch (not only on improvement), so an epoch
    # can be chosen afterwards from the recorded validation losses.
    torch.save(model.state_dict(), os.path.join(ckptDir, str(epoch) + ".pt"))
    print(f'Epoch: {epoch + 1}/{numberEpochs}, Training loss: {trainLoss:.4f}, Val loss: {valLoss:.4f}')

In [None]:
"""np.save("ckpts/smallVocabManifold/valLoss.npy", valLOSS)"""

In [24]:
"""
Simple beam-search algorithm.
"""

def _logAddExp(a, b):
    """Return log(exp(a) + exp(b)) computed stably in log space."""
    hi, lo = (a, b) if a >= b else (b, a)
    if hi == float("-inf"):
        # Both terms are log(0); avoid the nan from (-inf) - (-inf).
        return hi
    return hi + math.log1p(math.exp(lo - hi))


def beamSearch(ctcOutput, testInputLength, beamWidth = 5, blank = 33):
    """Decode one sequence of per-frame log-scores with a simple beam search.

    Args:
        ctcOutput: 2-D array-like of shape (frames, symbols); entries are added
            along a path, so they are treated as log-probabilities.
        testInputLength: number of leading frames to decode (cast to int).
        beamWidth: number of hypotheses kept after every frame.
        blank: index of the blank symbol, which never extends a hypothesis.

    Returns:
        The highest-scoring hypothesis as a tuple of symbol indices (the empty
        tuple when ``testInputLength`` is 0).

    Note:
        A hypothesis does not record whether it last consumed a blank, so a
        symbol equal to the hypothesis' last symbol never extends it. Genuine
        repeats separated by a blank (A, blank, A) therefore collapse to a
        single A. This matches the original decoder and is left unchanged.
    """
    _, V = ctcOutput.shape
    T = int(testInputLength)
    beams = [(tuple(), 0.0)]

    for t in range(T):
        # Read the frame once as plain floats instead of once per hypothesis.
        frame = [float(ctcOutput[t, symbol]) for symbol in range(V)]
        newBeams = {}
        for seq, logProb in beams:
            for symbol in range(V):
                if symbol != blank and (len(seq) == 0 or seq[-1] != symbol):
                    newSeq = seq + (symbol,)
                else:
                    # Blank, or a repeat of the last symbol: hypothesis unchanged.
                    newSeq = seq
                newLogProb = logProb + frame[symbol]

                if newSeq in newBeams:
                    # Merge paths with the same hypothesis. The previous
                    # log(exp(a) + exp(b)) underflowed to log(0) and raised
                    # ValueError once path scores dropped below about -745.
                    newBeams[newSeq] = _logAddExp(newBeams[newSeq], newLogProb)
                else:
                    newBeams[newSeq] = newLogProb

        beams = sorted(newBeams.items(), key = lambda x: x[1], reverse = True)[:beamWidth]

    return beams[0][0]

In [25]:
def findClosestTranscription(decodedTranscript, phoneticTranscription):
    """Return the Levenshtein (edit) distance between the two sequences.

    Despite the name, nothing is searched: this is a thin wrapper around
    ``Levenshtein.distance``.
    """
    return Levenshtein.distance(decodedTranscript, phoneticTranscription)

In [26]:
def testOperation(model, device, valLoader, Loss):
    model.eval()
    totalLoss = 0
    Outputs = []
    with torch.no_grad():
        for inputs, targets, inputLengths, targetLengths in valLoader:
            inputs, targets = inputs.to(device), targets.to(device)
            inputLengths, targetLengths = inputLengths.to(device), targetLengths.to(device)
            
            outputs = model(inputs) 

            loss = Loss(outputs, targets, inputLengths, targetLengths)
            totalLoss += loss.item()
            Outputs.append(outputs.transpose(0, 1))

    return Outputs, totalLoss / len(valLoader)

In [None]:
"""sortedIndices = np.argsort(valLOSS)

chosenEpoch = None

for idx in sortedIndices:
    if idx > 0 and idx < len(valLOSS) - 1: 
        currentLoss = valLOSS[idx]
        threshold = currentLoss * 1.2

        if valLOSS[idx - 1] <= threshold and valLOSS[idx + 1] <= threshold:
            chosenEpoch = idx
            break

epoch = chosenEpoch
print(epoch)"""

In [28]:
# map_location remaps the stored tensors onto the active device, so a checkpoint
# saved from a GPU run also loads when `device` is the CPU.
modelWeight = torch.load(
    "DATA/ckptsSmallVocab/ckptSmallVocabManifoldNoODE.pt",
    map_location = device,
    weights_only = True,
)
model.load_state_dict(modelWeight)
output, testLoss = testOperation(model, device, testDataloader, lossFunction)
# Join the per-batch outputs along dim 0. torch.cat is the canonical name;
# torch.concatenate is an alias that older torch releases lack.
output = torch.cat(output)

print("TEST LOSS: ", testLoss)

TEST LOSS:  0.7542564123868942


In [29]:
# Decode every test segment with beam search and map symbol ids back to phoneme
# strings. The sample count comes from the model output instead of a hard-coded
# 396, and the inner loop no longer reuses the outer loop's index name.
outputScores = output.cpu().numpy()   # one host transfer instead of one per sample
decodedOut = []
for sampleIdx in range(len(outputScores)):
    decodedSymbols = beamSearch(outputScores[sampleIdx], testInputLengths[sampleIdx])
    decodedOut.append([PHONE_DEF[symbol] for symbol in decodedSymbols])

In [30]:
"""Visualize decoded words."""

# Index of the test segment to inspect. Ground truth lives in the flat
# per-segment lists, where the test portion starts at offset 1600.
which = 0
truthIndex = 1600 + which
print("Decoded phoneme sequence: ", decodedOut[which])
print("Ground truth phoneme sequence: ", phonemizedWords[truthIndex])
print("Ground truth label: ", allWORDS[truthIndex])

Decoded phoneme sequence:  ['W', 'EH', 'N', 'Z', 'D', 'IY']
Ground truth phoneme sequence:  ['W', 'EH', 'N', 'Z', 'D', 'IY']
Ground truth label:  wednesday


In [31]:
# Per test segment: reference phoneme count, and edit distance between the
# decoded and reference phoneme sequences (test references start at offset 1600).
numDecoded = len(decodedOut)
phoneLENGTHS = [len(phonemizedWords[1600 + k]) for k in range(numDecoded)]
levs = [
    findClosestTranscription(decodedOut[k], phonemizedWords[1600 + k])
    for k in range(numDecoded)
]

In [32]:
# meanLength is the average reference phoneme count per segment; meanEdits is the
# average edit distance per segment. Their ratio is the phoneme error rate.
meanLength = np.mean(phoneLENGTHS)
meanEdits = np.mean(levs)
print("Mean length of sentences: ", meanLength)
print("Mean phoneme error rate (insertion errors + deletion errors + substitution errors): ", meanEdits)
print("Percent phoneme error: ", meanEdits/meanLength)

Mean length of sentences:  7.825757575757576
Mean phoneme error rate (insertion errors + deletion errors + substitution errors):  1.1893939393939394
Percent phoneme error:  0.15198451113262343


In [None]:
"""Calculate word error rate."""

In [33]:
"""Small vocab word-to-phoneme dictionary."""

# NOTE(review): pickle.load can execute arbitrary code while unpickling; only load
# this file from a trusted source.
with open('DATA/ckptsSmallVocab/smallVocabDict.pkl', 'rb') as dictFile:
    smallVocabDict = pickle.load(dictFile)

In [34]:
def findClosestTranscriptionWords(decodedTranscript, combinedDict):
    """Find the dictionary entries closest to ``decodedTranscript``.

    Args:
        decodedTranscript: decoded sequence to look up.
        combinedDict: mapping from key to its reference transcription.

    Returns:
        ``(closestKeys, minDistance)``: every key tied for the smallest edit
        distance, in dictionary order, and that distance. An empty dictionary
        yields ``([], inf)``.
    """
    scored = [
        (key, distance(decodedTranscript, phoneticTranscription))
        for key, phoneticTranscription in combinedDict.items()
    ]
    if not scored:
        return [], float('inf')

    minDistance = min(dist for _, dist in scored)
    closestKeys = [key for key, dist in scored if dist == minDistance]
    return closestKeys, minDistance

def checkCorrectness(actual, predicted):
    """Count the words of ``actual`` that every candidate prediction gets right.

    Args:
        actual: reference phrase; words are separated by whitespace.
        predicted: list of candidate phrases (strings). It is NOT modified, so
            the same list can be passed again.

    Returns:
        The number of positions ``i`` where every candidate has a word at ``i``
        and that word equals the reference word. With no candidates the result
        is 0, so an empty prediction is never scored as fully correct.
    """
    actualWords = actual.split()
    # Split into a new list rather than overwriting the caller's entries.
    candidates = [prediction.split() for prediction in predicted]
    if not candidates:
        return 0

    return sum(
        1
        for position, word in enumerate(actualWords)
        if all(
            position < len(candidate) and candidate[position] == word
            for candidate in candidates
        )
    )

In [35]:
# For each decoded phoneme sequence, keep every vocabulary entry tied for the
# smallest edit distance.
finalDecode = [
    findClosestTranscriptionWords(decoded, smallVocabDict)[0]
    for decoded in decodedOut
]

# Total phoneme-level edit distance over the test segments.
levDistance = sum(
    distance(decodedOut[k], phonemizedWords[1600 + k])
    for k in range(len(decodedOut))
)

totalWords = 0
singleCorrects = 0
singleTotal = 0
notSingleCorrects = 0
notSingleTotal = 0

# One pass over the 396 test segments, tallying single-word and multi-word
# references separately.
for k in range(396):
    reference = allWORDS[1600 + k].replace("-", " ").lower()
    referenceWords = reference.split(" ")
    candidates = finalDecode[k]

    totalWords += len(referenceWords)

    if len(referenceWords) == 1:
        singleTotal += 1
        # Correct only when the lookup is unambiguous and matches the reference.
        if len(candidates) == 1 and referenceWords[0] in candidates:
            singleCorrects += 1
    else:
        notSingleTotal += len(referenceWords)
        # Word-by-word agreement between the reference and the tied candidates.
        notSingleCorrects += checkCorrectness(reference, candidates)

wordAccuracy = (notSingleCorrects + singleCorrects)/totalWords
print("Word decoding accuracy: ", wordAccuracy)
print(" ")

LEV = [levDistance]
ACC = [wordAccuracy]

Word decoding accuracy:  0.8155339805825242
 
