In [1]:
%load_ext autoreload

In [2]:
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from g2p_en import G2p
import re

from basicOperations.manifoldOperations import matrixDistance, frechetMean

from rnn import euclideanRnnNato
import math

import pickle
from Levenshtein import distance
import os

import matplotlib.pyplot as plt

In [None]:
""" Proof for Figure 5 and Table 7.
Use the trained checkpoints from natorWordsEuclidean.ipynb to test it here on rainbow passage. 
Reported value in the article is the average from here and checkGrandfather.ipynb.
"""

In [3]:
dev = "cuda:0" 
device = torch.device(dev)

In [4]:
""" Phonemize nato alphabets."""

# The 26 NATO code words, listed in A..Z order.
natoAlphabets = [
    "Alfa", "Bravo", "Charlie", "Delta", "Echo",
    "Foxtrot", "Golf", "Hotel", "India", "Juliette",
    "Kilo", "Lima", "Mike", "November", "Oscar",
    "Papa", "Quebec", "Romeo", "Sierra", "Tango",
    "Uniform", "Victor", "Whiskey", "X-ray", "Yankee",
    "Zulu"]

# "A".."Z" as single-character strings.
englishAlphabets = [chr(ord("A") + offset) for offset in range(26)]

# Phone inventory; a phone's position in this list is its class id ('SIL' is id 32).
PHONE_DEF = ['AA', 'AE', 'AH', 'AO', 'AY', 'B', 'CH', 'D', 'EH', 'ER', 'EY', 'F', 'G', 'HH', 'IH',
 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'P', 'R', 'S', 'T', 'UW', 'V', 'W', 'Y', 'Z', 'SIL']

# Letter <-> index lookups (A <-> 0 ... Z <-> 25).
alphabetNumber = {letter: idx for idx, letter in enumerate(englishAlphabets)}
numberAlphabet = dict(enumerate(englishAlphabets))

def phoneToId(p):
    """Return the class id of phone ``p``, i.e. its position in PHONE_DEF.

    Raises ValueError (from ``list.index``) when ``p`` is not in PHONE_DEF.
    """
    return PHONE_DEF.index(p)

g2p = G2p()

# Phone sequence (stress digits removed) for every NATO code word, in A..Z order.
phonemizedAlphabets = []
for word in natoAlphabets:
    # Keep letters, hyphens, spaces and apostrophes; lower-case for g2p.
    cleaned = re.sub(r'[^a-zA-Z\- \']', '', word.strip())
    cleaned = cleaned.replace('--', '').lower()
    # Drop stress digits, then keep only tokens that start with an upper-case letter
    # (this filters out non-phone tokens such as separators).
    stressless = (re.sub(r'[0-9]', '', token) for token in g2p(cleaned))
    phonemizedAlphabets.append([ph for ph in stressless if re.match(r'[A-Z]+', ph)])

# Convert each code word's phones to class ids (phoneToId raises ValueError for any
# phone outside PHONE_DEF).
phone2index = [[phoneToId(phone) for phone in phones] for phones in phonemizedAlphabets]

# Fixed label width; the test-label array built later in this notebook uses the same width.
maxLabelLength = 8
longestLabel = max((len(ids) for ids in phone2index), default=0)
assert longestLabel <= maxLabelLength, (
    f"a code word has {longestLabel} phones; increase maxLabelLength (currently {maxLabelLength})"
)

# Right-pad each row with 0. Only the first labelLengths[i] entries of row i are real
# labels; the lengths are passed to the loss alongside the targets.
phonemizedLabels = np.zeros((len(phone2index), maxLabelLength))
for i, ids in enumerate(phone2index):
    phonemizedLabels[i, :len(ids)] = ids

# Number of real (unpadded) phones per code word, stored as float64 like the label matrix.
labelLengths = np.array([len(ids) for ids in phone2index], dtype=float)

In [5]:
def testOperation(model, device, testLoader, Loss):
    model.eval()
    totalLoss = 0
    Outputs = []
    with torch.no_grad():
        for inputs, targets, inputLengths, targetLengths in testLoader:
            inputs, targets = inputs.to(device), targets.to(device)
            inputLengths, targetLengths = inputLengths.to(device), targetLengths.to(device)
            
            outputs = model(inputs, inputLengths.cpu()) 

            loss = Loss(outputs, targets, inputLengths, targetLengths)
            totalLoss += loss.item()
            Outputs.append(outputs.transpose(0, 1))

    return Outputs, totalLoss / len(testLoader)

In [6]:
# Dataset layout constants.
numberAlphabets = 26                                  # letters A-Z
trialsPerAlphabet = 20
numberTrials = trialsPerAlphabet * numberAlphabets    # trials across all letters
numberChannels = 22                                   # side length of each covariance matrix
windowLength = 7500                                   # per-trial length; presumably in samples — TODO confirm

In [42]:
""" Upload data. """

subjectNumber = 4
subject = f"Subject{subjectNumber}"
subjectDir = "DATA/" + subject + "/"

DATA = np.load(subjectDir + "rainbowPassage.npy")

# Standardize along the last (time) axis; 1e-5 prevents division by zero for flat signals.
mean = DATA.mean(axis=-1)
std = DATA.std(axis=-1)
DATA = (DATA - mean[..., None]) / (std[..., None] + 1e-5)

Labels = np.load(subjectDir + "rainbowPassageLabels.npy")

In [43]:
print(" Number of articulation in the rainbow passage: ", len(Labels))

 Number of articulation in the rainbow passage:  1427


In [44]:
"""Chunk the data. Load eigenvectors. Approximately diagonalize it."""

# Sliding-window layout: window w covers samples [w*windowHop, w*windowHop + windowSize).
# With these values the last window ends at 45*150 + 750 = 7500 samples.
numberWindows = 46
windowHop = 150
windowSize = 750
assert (numberWindows - 1) * windowHop + windowSize <= DATA.shape[-1], (
    "trials are too short for the requested windows; slices would be silently truncated"
)

slicedMatrices = np.zeros((len(Labels), numberWindows, numberChannels, numberChannels))
for trial in range(len(Labels)):
    for w in range(numberWindows):
        start = w * windowHop
        segment = DATA[trial, :, start:start + windowSize]
        cov = (1 / windowSize) * segment @ segment.T
        # Shrink toward a scaled identity. The trace is NOT divided by the channel count.
        # NOTE(review): kept as-is so features match whatever produced the saved
        # eigenvectors/checkpoints — confirm against the training notebook.
        slicedMatrices[trial, w] = 0.9 * cov + 0.1 * np.trace(cov) * np.eye(numberChannels)


eigenvectors = np.load("DATA/ckptsNatoWords/eigenVectors" + str(subjectNumber) + ".npy")


# Rotate every window covariance into the eigenvector basis, V^T C V. np.matmul
# broadcasts the (channels x channels) products over the leading (trial, window) axes.
afterMatrices = eigenvectors.T @ slicedMatrices @ eigenvectors
print(afterMatrices.shape)

(1427, 46, 22, 22)


In [45]:
class BaseDataset(Dataset):
    """Map-style dataset of ``(inputSeq, targetSeq, inputLength, targetLength)`` items.

    ``data[index]`` is returned as float32. ``inputLength`` is the size of ``data``'s
    second axis and is the same for every item. ``targetLength`` comes from the
    per-item ``targetLength`` array, converted to ``int``.
    """

    def __init__(self, data, labels, targetLength):
        self.data, self.labels, self.targetLength = data, labels, targetLength

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return (
            self.data[index].astype('float32'),
            self.labels[index],
            int(self.data.shape[1]),
            int(self.targetLength[index]),
        )

In [46]:
# Per-trial inputs and padded CTC targets. Each trial's letter index in Labels selects
# its phone-id row and its label length (vectorized fancy indexing, copied as float64).
testFeatures = np.array(afterMatrices[:len(Labels)], dtype=np.float64)
testLabels = np.array(phonemizedLabels[Labels], dtype=np.float64)
testLabelLengths = np.array(labelLengths[Labels], dtype=np.float64)

In [47]:
# shuffle=False keeps trial order, so row i of the collected outputs matches Labels[i].
testDataset = BaseDataset(testFeatures, testLabels, testLabelLengths)
testDataloader = DataLoader(dataset=testDataset, shuffle=False, batch_size=32)

In [49]:
"""A simple beamsearch algorithm."""

def beamSearch(ctcOutput, testInputLength, beamWidth = 5, blank = 32):
    _, V = ctcOutput.shape
    T = int(testInputLength)
    beams = [(tuple(), 0.0)] 
    
    for t in range(T):
        newBeams = {}
        for seq, logProb in beams:
            for symbol in range(V):
                newSeq = list(seq)
                
                if symbol == blank:
                    newLogProb = logProb + ctcOutput[t, symbol].item()
                    newSeq = tuple(newSeq)
                
                elif len(seq) == 0 or seq[-1] != symbol:
                    newSeq.append(symbol)
                    newLogProb = logProb + ctcOutput[t, symbol].item()
                    newSeq = tuple(newSeq)
                
                else:
                    newLogProb = logProb + ctcOutput[t, symbol].item()
                    newSeq = tuple(newSeq)
                
                if newSeq in newBeams:
                    newBeams[newSeq] = math.log(math.exp(newBeams[newSeq]) + math.exp(newLogProb))
                else:
                    newBeams[newSeq] = newLogProb

        beams = sorted(newBeams.items(), key=lambda x: x[1], reverse=True)[:beamWidth]
    
    return beams[0][0]

def findBestMatch(prediction, alphabetSequences):
    """Find the key(s) whose sequence has the smallest edit distance to ``prediction``.

    Args:
        prediction: decoded symbol sequence.
        alphabetSequences: mapping of key -> reference sequence.

    Returns:
        (keys, minDistance): every key tied at the minimum distance, in mapping order,
        and that distance. An empty mapping gives ``([], inf)``.
    """
    scored = [(key, distance(prediction, reference))
              for key, reference in alphabetSequences.items()]
    if not scored:
        return [], float('inf')

    best = min(score for _, score in scored)
    return [key for key, score in scored if score == best], best

# Letter -> phone-id sequence of its NATO code word ("A" -> ids of "Alfa", ...).
alphabetSequences = {englishAlphabets[k]: phone2index[k] for k in range(26)}

In [50]:
DEVICE = "cuda:0"

numberEpochs = 100

model = euclideanRnnNato.RnnNet(33, 23, device, numLayers = 1).to(device)
numParams = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(numParams)
lossFunction = nn.CTCLoss(blank = 32, zero_infinity = True)
rnnOptimizer = optim.Adam(model.parameters(), lr = 0.001, weight_decay = 1e-3)

1280121


In [None]:
"""FOLDERPATH = "ckpts/natoEuclidean/"
valLoss = np.load(FOLDERPATH + "valLoss.npy")
chosenEpoch = None
sortedIndices = np.argsort(valLoss)

for idx in sortedIndices:
    if idx > 0 and idx < len(valLoss) - 1: 
        currentLoss = valLoss[idx]
        threshold = currentLoss * 1.2

        if valLoss[idx - 1] <= threshold and valLoss[idx + 1] <= threshold:
            chosenEpoch = idx
            break

valEpoch = chosenEpoch 
print(valEpoch)"""

66


In [51]:
checkpoint = torch.load("DATA/ckptsNatoWords/subject" + str(subjectNumber) + "Euclidean.pt", weights_only = True)
model.load_state_dict(checkpoint)
output, valLoss = testOperation(model, device, testDataloader, lossFunction)
output = torch.concatenate(output)
print(valLoss)

2.9231127050187853


In [52]:
"""Calculate PER."""

# Decode exactly as many frames as BaseDataset reports as the CTC input length (46 here).
numberFrames = testFeatures.shape[1]
# One device->host transfer for all trials instead of one per trial.
decodableOutput = output.cpu().numpy()

lev = []
labelLength = []
for i in range(len(Labels)):
    target = alphabetSequences[numberAlphabet[Labels[i]]]
    decodedSymbols = beamSearch(decodableOutput[i], numberFrames)
    lev.append(distance(decodedSymbols, target))
    labelLength.append(len(target))

# Corpus-level PER: total edit distance divided by total reference length.
print("PER: ", np.mean(lev)/np.mean(labelLength))

PER:  0.5812924166420773


In [53]:
"""Calculate CER."""

corrects = 0
for i in range(len(Labels)):
    decodedSymbols = beamSearch(output[i].cpu().numpy(), 46)
    bestLetters  = findBestMatch(decodedSymbols, alphabetSequences)
    if len(bestLetters[0]) == 1:
        if numberAlphabet[Labels[i]] in bestLetters[0]:
            corrects += 1

print("CER: ", corrects/len(Labels))

CER:  0.4386825508058865
