# Wav2Letter Example using Google Speech Command Dataset

The Google Speech Commands dataset can be found [here](https://www.kaggle.com/c/tensorflow-speech-recognition-challenge/data). It was chosen as a quick and convenient way to test Wav2Letter performance.

In [None]:
"""Data module. Downloads data. preprocess data. Data feeder pipeline

    TODO:
        * Build a more memory efficient data feeder pipeline
"""
import os
import numpy as np
from sonopy import mfcc_spec
from scipy.io.wavfile import read
from tqdm import tqdm_notebook as nb_tqdm
from tqdm import tqdm
import random
import pickle


class IntegerEncode:
    """Map label characters to integer ids and back.

    Id 0 is reserved for the blank symbol "-" and id 1 for the padding token
    "pad". Every other character found in ``labels`` receives the next free
    id, in order of first appearance.

    Args:
        labels (list): label strings, shape (n_samples,)
    """

    def __init__(self, labels):
        # 0 is the blank label, 1 is padding
        self.char2index = {"-": 0, "pad": 1}
        self.index2char = {idx: ch for ch, idx in self.char2index.items()}
        self.grapheme_count = len(self.char2index)
        self.process(labels)
        # every encoded label is right-padded to at least this many ids
        self.max_label_seq = 6

    def process(self, labels):
        """Register every not-yet-seen character of ``labels``.

        Args:
            labels (list): label strings, shape (n_samples,)
        """
        for ch in "".join(labels):
            if ch in self.char2index:
                continue
            new_id = self.grapheme_count
            self.char2index[ch] = new_id
            self.index2char[new_id] = ch
            self.grapheme_count = new_id + 1

    def convert_to_ints(self, label):
        """Encode ``label`` as ids, right-padded with the "pad" id.

        Labels longer than ``max_label_seq`` are returned unpadded and
        untruncated. Unknown characters raise KeyError.

        Args:
            label (str): string to encode

        Returns:
            list: ids, length max(len(label), max_label_seq)
        """
        encoded = [self.char2index[ch] for ch in label]
        shortfall = self.max_label_seq - len(encoded)
        if shortfall > 0:
            encoded.extend([self.char2index["pad"]] * shortfall)
        return encoded

    def save(self, file_path):
        """Pickle this encoder's attribute dict to ``<file_path>/int_encoder.pkl``.

        Args:
            file_path (str): directory to write the pickle into
        """
        target = os.path.join(file_path, "int_encoder.pkl")
        with open(target, 'wb') as handle:
            pickle.dump(self.__dict__, handle)


def normalize(values):
    """Normalize values to mean 0 and std 1.

    Mean and standard deviation are taken over the whole array (all frames
    and features together), not per feature.

    Args:
        values (np.array): shape (frame_len, features)

    Returns:
        np.array: normalized features, same shape. A constant input
        (std == 0) is returned centered (all zeros) instead of all NaN.
    """
    centered = values - np.mean(values)
    std = np.std(values)
    if std == 0:
        # avoid 0/0 -> NaN, which would poison training downstream
        return centered
    return centered / std


class GoogleSpeechCommand():
    """Google Speech Commands loader and MFCC feature extractor.

    Data set can be found here
        https://www.kaggle.com/c/tensorflow-speech-recognition-challenge/data

    ``data_path`` must contain one sub-directory per entry of ``self.labels``
    holding that label's audio files.
    """

    def __init__(self, data_path="../input/tensorflow-speech-recognition-challenge/train", sr=16000):
        self.data_path = data_path
        self.labels = [
            'right', 'eight', 'cat', 'tree', 'bed', 'happy', 'go', 'dog', 'no', 
            'wow', 'nine', 'left', 'stop', 'three', 'sheila', 'one', 'bird', 'zero',
            'seven', 'up', 'marvin', 'two', 'house', 'down', 'six', 'yes', 'on', 
            'five', 'off', 'four'
        ]
        self.intencode = IntegerEncode(self.labels)
        self.sr = sr
        # every MFCC matrix is zero-padded / truncated to this many frames
        self.max_frame_len = 225

    def _collect_meta_data(self, max_per_label=None):
        """List (audio_path, label) pairs for the clips under ``data_path``.

        Args:
            max_per_label (int or None): if given, keep at most this many
                clips per label (handy for quick debugging runs).

        Returns:
            list of (str, str): (audio_path, label) pairs, grouped by label.
        """
        meta_data = []
        for label in self.labels:
            label_dir = os.path.join(self.data_path, label)
            # sorted so that a max_per_label subset is reproducible
            file_names = sorted(os.listdir(label_dir))
            if max_per_label is not None:
                file_names = file_names[:max_per_label]
            meta_data.extend(
                (os.path.join(label_dir, name), label) for name in file_names
            )
        return meta_data

    def get_data(self, progress_bar=True, max_per_label=None):
        """Currently returns mfccs and integer encoded data

        Args:
            progress_bar (bool): wrap the feature-extraction loop in tqdm.
            max_per_label (int or None): optional cap on clips per label;
                None (default) uses every clip.

        Returns:
            (list, list): 
                inputs shape (sample_size, frame_len, mfcs_features)
                targets shape (sample_size, seq_len)  seq_len is variable
        """
        pg = tqdm if progress_bar else lambda x: x

        meta_data = self._collect_meta_data(max_per_label)
        random.shuffle(meta_data)

        inputs, targets = [], []
        for audio_path, label in pg(meta_data):
            _, audio = read(audio_path)
            mfccs = mfcc_spec(
                audio, self.sr, window_stride=(160, 80),
                fft_size=512, num_filt=20, num_coeffs=13
            )
            mfccs = normalize(mfccs)
            # truncate overly long clips (np.pad rejects a negative width),
            # then zero-pad along time to exactly max_frame_len frames
            mfccs = mfccs[:self.max_frame_len]
            diff = self.max_frame_len - mfccs.shape[0]
            mfccs = np.pad(mfccs, ((0, diff), (0, 0)), "constant")
            inputs.append(mfccs)

            targets.append(self.intencode.convert_to_ints(label))
        return inputs, targets

    @staticmethod
    def save_vectors(file_path, x, y):
        """saves input and targets vectors as x.npy and y.npy
        
        Args:
            file_path (str): path to save numpy array
            x (list): inputs
            y (list): targets
        """
        x_file = os.path.join(file_path, "x")
        y_file = os.path.join(file_path, "y")
        np.save(x_file, np.asarray(x))
        np.save(y_file, np.asarray(y))

    @staticmethod
    def load_vectors(file_path):
        """load inputs and targets
        
        Args:
            file_path (str): path to load targets from
        
        Returns:
            inputs, targets: np.array, np.array
        """
        x_file = os.path.join(file_path, "x.npy")
        y_file = os.path.join(file_path, "y.npy")

        inputs = np.load(x_file)
        targets = np.load(y_file)
        return inputs, targets


# if __name__ == "__main__":
#     gs = GoogleSpeechCommand()
#     inputs, targets = gs.get_data()
#     gs.save_vectors("../input/", inputs, targets)
#     gs.intencode.save("../input/")
#     print("preprocessed and saved")

In [None]:
#train
"""Trains Wav2Letter model using speech data
    
    TODO:
        * show accuracy metrics
        * add more diverse datasets
        * train, val, test split
"""
import argparse
import torch
import torch.nn as nn
import torch.optim as optim



def train(batch_size, epochs):
    """Train Wav2Letter on the preprocessed Google Speech Commands vectors.

    Loads ``x.npy``/``y.npy`` from ``../input/``, fits the model with CTC loss
    and Adam (lr=1e-4), then greedy-decodes the first sample and prints it
    next to its target.

    Args:
        batch_size (int): mini-batch size passed to ``Wav2Letter.fit``
        epochs (int): number of training epochs
    """
    speech = GoogleSpeechCommand()
    raw_x, raw_y = speech.load_vectors("../input/")

    n_mfcc = 13
    n_graphemes = speech.intencode.grapheme_count

    print("training google speech dataset")
    print("data size", len(raw_x))
    for tag, value in (
        ("batch_size", batch_size),
        ("epochs", epochs),
        ("num_mfcc_features", n_mfcc),
        ("grapheme_count", n_graphemes),
    ):
        print(tag, value)

    x = torch.Tensor(raw_x)
    y = torch.IntTensor(raw_y)
    print("input shape", x.shape)
    print("target shape", y.shape)

    net = Wav2Letter(n_mfcc, n_graphemes)
    print(net.layers)

    criterion = nn.CTCLoss()
    opt = optim.Adam(net.parameters(), lr=1e-4)

    # Conv1d treats each MFCC coefficient as a channel:
    # (sample_size, frame_len, mfcc_features) -> (sample_size, mfcc_features, frame_len)
    # https://pytorch.org/docs/stable/nn.html#torch.nn.Conv1d
    x = x.transpose(1, 2)
    print("transposed input", x.shape)

    net.fit(x, y, opt, criterion, batch_size, epoch=epochs)

    first_target = y[0]
    predicted = GreedyDecoder(net.eval(x[0]))

    print("sample target", first_target)
    print("predicted", predicted)


# if __name__ == "__main__":
#     parser = argparse.ArgumentParser(description='Wav2Letter')
#     parser.add_argument('--batch-size', type=int, default=64, metavar='N',
#                         help='input batch size for training (default: 64)')
#     parser.add_argument('--epochs', type=int, default=100, metavar='N',
#                         help='total epochs (default: 100)')

#     args = parser.parse_args()

#     batch_size = args.batch_size
#     epochs = args.epochs
#     train(batch_size, epochs)

In [None]:
"""Beam Decoder Module.

    TODO: 
        * write beam search decoder
        * use KenLM langauge model to aid decoding
        * add WER and CER metrics
"""
from __future__ import print_function
from __future__ import division

from collections import deque
import copy
import re

import numpy as np
import pandas as pd
from torch import topk

#prefix tree
class Node:
    "class representing nodes in a prefix tree"
    def __init__(self):
        self.children = {}  # maps next character -> child Node
        self.isWord = False  # True if the root-to-here path spells an added word

    def __str__(self):
        return 'isWord: ' + str(self.isWord) + '; children: ' + ''.join(self.children)


class PrefixTree:
    """Prefix tree (trie) over words.

    Each edge carries one character; a node's ``isWord`` flag marks that the
    path from the root to it spells a word that was added.
    """
    def __init__(self):
        self.root = Node()

    def addWord(self, text):
        "add word to prefix tree (an empty string adds nothing)"
        node = self.root
        for c in text:
            if c not in node.children:
                node.children[c] = Node()
            node = node.children[c]
        if text:
            # flag the node of the last character; the root is never a word
            node.isWord = True

    def addWords(self, words):
        "add every word of an iterable of strings"
        for w in words:
            self.addWord(w)

    def getNode(self, text):
        "get node representing given text, or None if text is not a prefix in the tree"
        node = self.root
        for c in text:
            node = node.children.get(c)
            if node is None:
                return None
        return node

    def isWord(self, text):
        "True if text was added as a word"
        node = self.getNode(text)
        return node.isWord if node is not None else False

    def getNextChars(self, text):
        "get all characters which may directly follow given text (insertion order)"
        node = self.getNode(text)
        return list(node.children) if node is not None else []

    def getNextWords(self, text):
        """get all words of which given text is a prefix (including the text itself, if it is a word)

        Words are returned in breadth-first order (shorter words first).
        """
        node = self.getNode(text)
        if node is None:
            return []
        words = []
        # deque: O(1) pops from the front (del list[0] is O(n) per pop)
        queue = deque([(node, text)])
        while queue:
            current, prefix = queue.popleft()
            for c, child in current.children.items():
                queue.append((child, prefix + c))
            if current.isWord:
                words.append(prefix)
        return words

    def dump(self):
        "print every node, breadth-first, starting at the root"
        queue = deque([self.root])
        while queue:
            current = queue.popleft()
            queue.extend(current.children.values())
            print(current)




class Optical:
    """Optical (character-model) score of a beam.

    prBlank: probability that the beam's paths end with a blank.
    prNonBlank: probability that they end with a non-blank label.
    """
    def __init__(self, prBlank=0, prNonBlank=0):
        self.prBlank, self.prNonBlank = prBlank, prNonBlank


class Textual:
    """Textual (language-model) state and score of a beam.

    text: decoded text so far.
    wordHist: completed words, in order.
    wordDev: the word currently being built.
    prUnnormalized / prTotal: LM scores, both starting at 1.0.
    """
    def __init__(self, text=''):
        self.text = text
        self.wordHist = []
        self.wordDev = ''
        self.prUnnormalized = self.prTotal = 1.0


class Beam:
    "beam with text, optical and textual score"
    # One hypothesis of word beam search: decoded text plus LM state
    # (self.textual), blank / non-blank path probabilities (self.optical),
    # and the language model used to constrain and score extensions.
    def __init__(self, lm, useNGrams):
        "creates genesis beam"
        # genesis beam: empty text with prBlank=1.0 and prNonBlank=0.0
        self.optical=Optical(1.0, 0.0)
        self.textual=Textual('')
        self.lm=lm # language model: supplies allowed next chars and n-gram probs
        self.useNGrams=useNGrams # if True, createChildBeam also updates the LM score
        
        
    def mergeBeam(self, beam):
        "merge probabilities of two beams with same text"
        # Only the optical probabilities are summed; self's textual score is
        # kept unchanged. Raises Exception if the two texts differ.
        
        if self.getText()!=beam.getText():
            raise Exception('mergeBeam: texts differ')
        
        self.optical.prNonBlank+=beam.getPrNonBlank()
        self.optical.prBlank+=beam.getPrBlank()
        
        
    def getText(self):
        return self.textual.text
        
    
    def getPrBlank(self):
        return self.optical.prBlank
    
    
    def getPrNonBlank(self):
        return self.optical.prNonBlank
    

    def getPrTotal(self):
        # optical probability of this text regardless of how the paths end
        return self.getPrBlank()+self.getPrNonBlank()
    
    
    def getPrTextual(self):
        return self.textual.prTotal
    
    
    def getNextChars(self):
        # characters the LM allows after the word currently being built
        return self.lm.getNextChars(self.textual.wordDev)
        
        
    def createChildBeam(self, newChar, prBlank, prNonBlank):
        "extend beam by new character and set optical score"
        # Returns a NEW beam; self is not modified (textual state is deep-copied).
        # newChar=='' keeps the text unchanged and only sets the optical score.
        beam=Beam(self.lm, self.useNGrams)
        
        # copy textual information
        beam.textual=copy.deepcopy(self.textual)
        beam.textual.text+=newChar
        
        # do textual calculations only if beam gets extended
        if newChar!='':
            if self.useNGrams: # use unigrams and bigrams 
            
                # if new char occurs inside a word
                if newChar in beam.lm.getWordChars():
                    beam.textual.wordDev+=newChar
                    # all corpus words that start with the developing word
                    nextWords=beam.lm.getNextWords(beam.textual.wordDev)
                    
                    # no complete word in text, then use unigram of all possible next words
                    numWords=len(beam.textual.wordHist)
                    prSum=0
                    if numWords==0:
                        for w in nextWords:
                            prSum+=beam.lm.getUnigramProb(w)
                    # take last complete word and sum up bigrams of all possible next words
                    else:
                        lastWord=beam.textual.wordHist[-1]
                        for w in nextWords:
                            prSum+=beam.lm.getBigramProb(lastWord, w)
                    beam.textual.prTotal=beam.textual.prUnnormalized*prSum
                    # with at least one completed word, take the (numWords+1)-th
                    # root -- presumably to normalize for the number of factors
                    beam.textual.prTotal=beam.textual.prTotal**(1/(numWords+1)) if numWords>=1 else beam.textual.prTotal
                    
                # if new char does not occur inside a word
                else:
                    # if current word is not empty, add it to history
                    if beam.textual.wordDev!='':
                        beam.textual.wordHist.append(beam.textual.wordDev)
                        beam.textual.wordDev=''
                        
                        # score with unigram (first word) or bigram (all other words) probability
                        numWords=len(beam.textual.wordHist)
                        if numWords==1:
                            beam.textual.prUnnormalized*=beam.lm.getUnigramProb(beam.textual.wordHist[-1])
                            beam.textual.prTotal=beam.textual.prUnnormalized
                        elif numWords>=2:
                            beam.textual.prUnnormalized*=beam.lm.getBigramProb(beam.textual.wordHist[-2], beam.textual.wordHist[-1])
                            # numWords-th root of the accumulated product
                            beam.textual.prTotal=beam.textual.prUnnormalized**(1/numWords)
            
            else: # don't use unigrams and bigrams, just keep wordDev up to date
                if newChar in beam.lm.getWordChars():
                    beam.textual.wordDev+=newChar
                else:
                    beam.textual.wordDev=''
        
        # set optical information
        beam.optical.prBlank=prBlank
        beam.optical.prNonBlank=prNonBlank
        return beam
        
        
    def __str__(self):
        # "text";optical total;textual total;unnormalized textual score
        return '"'+self.getText()+'"'+';'+str(self.getPrTotal())+';'+str(self.getPrTextual())+';'+str(self.textual.prUnnormalized)


class BeamList:
    "list of beams at specific time-step, keyed (and merged) by beam text"
    def __init__(self):
        self.beams={}  # beam text -> Beam

    def addBeam(self, beam):
        "add or merge new beam into list"
        text=beam.getText()
        # add if text not yet known
        if text not in self.beams:
            self.beams[text]=beam
        # otherwise merge with existing beam
        else:
            self.beams[text].mergeBeam(beam)

    def getBestBeams(self, num):
        """return best beams, specify the max. number of beams to be returned (beam width)

        Beams are ranked by optical probability times textual (LM) probability.
        """
        lmWeight=1
        return sorted(self.beams.values(), reverse=True, key=lambda x:x.getPrTotal()*(x.getPrTextual()**lmWeight))[:num]

    def deletePartialBeams(self, lm):
        "delete beams for which last word is not finished"
        # iterate over a snapshot: deleting from a dict while iterating its
        # live items() view raises RuntimeError in Python 3
        for k, v in list(self.beams.items()):
            lastWord=v.textual.wordDev
            if (lastWord!='') and (not lm.isWord(lastWord)):
                del self.beams[k]

    def completeBeams(self, lm):
        "complete beams such that last word is complete word"
        for v in self.beams.values():
            lastPrefix=v.textual.wordDev
            if lastPrefix=='' or lm.isWord(lastPrefix):
                continue

            # get word candidates for this prefix
            words=lm.getNextWords(lastPrefix)
            # if there is just one candidate, then the last prefix can be extended to it
            if len(words)==1:
                word=words[0]
                # append the part of the word that follows the prefix
                v.textual.text+=word[len(lastPrefix)-len(word):]

    def dump(self):
        "print every beam, with non-ASCII characters replaced by '?'"
        for beam in self.beams.values():
            # `unicode` does not exist in Python 3; encode/decode keeps the
            # original intent of mapping output to ASCII
            print(str(beam).encode('ascii', 'replace').decode('ascii'))



class LanguageModel:
    "unigram/bigram LM, add-k smoothing"
    def __init__(self, corpus, chars, wordChars):
        """Build unigrams, bigrams and a prefix tree from ``corpus``.

        Args:
            corpus (str): training text; every maximal run of ``wordChars``
                characters is one word.
            chars (str): all characters (labels) the decoder may emit.
            wordChars (str): characters that may occur inside a word. Each
                character is taken literally (it is regex-escaped).
        """
        # escape so characters such as '-', ']', '^' or '\\' are matched
        # literally instead of forming a range or breaking the character class
        self.wordCharPattern='['+re.escape(wordChars)+']'
        self.wordPattern=self.wordCharPattern+'+'
        words=re.findall(self.wordPattern, corpus)
        uniqueWords=set(words)
        self.numWords=len(words)
        self.numUniqueWords=len(uniqueWords)
        self.smoothing=True
        self.addK=1.0 if self.smoothing else 0.0

        # unigram probabilities over lower-cased words
        self.unigrams={}
        for w in words:
            w=w.lower()
            if w not in self.unigrams:
                self.unigrams[w]=0
            self.unigrams[w]+=1/self.numWords

        # unnormalized bigrams: each seen pair starts at addK and gets +1 per occurrence
        bigrams={}
        for w1, w2 in zip(words, words[1:]):
            successors=bigrams.setdefault(w1.lower(), {})
            w2=w2.lower()
            if w2 not in successors:
                successors[w2]=self.addK # add-K
            successors[w2]+=1

        # normalize bigrams; the denominator also reserves addK mass per unique word
        for successors in bigrams.values():
            probSum=self.numUniqueWords*self.addK # add-K smoothing
            for count in successors.values():
                probSum+=count
            for w2 in successors:
                successors[w2]/=probSum
        self.bigrams=bigrams

        # create prefix tree
        self.tree=PrefixTree() # create empty tree
        self.tree.addWords(words) # add all words to tree

        # list of all chars, word chars and nonword chars
        self.allChars=chars
        self.wordChars=wordChars
        # non-word chars keep their first-appearance order from `chars`, so
        # getNextChars (and thus beam expansion order) is deterministic;
        # a set difference would order them by hash seed
        self.nonWordChars=''.join(
            c for c in dict.fromkeys(chars)
            if re.fullmatch(self.wordCharPattern, c) is None
        )

    def getNextWords(self, text):
        "text must be prefix of a word"
        return self.tree.getNextWords(text)

    def getNextChars(self, text):
        "text must be prefix of a word"
        nextChars=''.join(self.tree.getNextChars(text))

        # if in between two words or if word ends, add non-word chars
        if (text=='') or (self.isWord(text)):
            nextChars+=self.getNonWordChars()

        return nextChars

    def getWordChars(self):
        return self.wordChars

    def getNonWordChars(self):
        return self.nonWordChars

    def getAllChars(self):
        return self.allChars

    def isWord(self, text):
        return self.tree.isWord(text)

    def getUnigramProb(self, w):
        "prob of seeing word w (0 for unknown words)."
        return self.unigrams.get(w.lower(), 0)

    def getBigramProb(self, w1, w2):
        "prob of seeing words w1 w2 next to each other (0 if w1 never precedes a word)."
        w1=w1.lower()
        w2=w2.lower()
        successors=self.bigrams.get(w1)
        if successors is None:
            return 0
        val=successors.get(w2)
        if val is not None:
            return val
        # NOTE(review): this smoothing denominator differs from the one used
        # to normalize seen bigrams above (numUniqueWords*addK + counts);
        # kept as-is, confirm intent.
        return self.addK/(self.getUnigramProb(w1)*self.numUniqueWords+self.numUniqueWords)


def wordBeamSearch(mat, beamWidth, lm, useNGrams):
    """decode matrix, use given beam width and language model

    Args:
        mat (np.array): shape (T, C) per-time-step label probabilities
            (e.g. softmax output). Column ``len(lm.getAllChars())`` is the
            blank, i.e. the blank is expected to be the last label.
        beamWidth (int): number of best beams expanded per time-step
        lm (LanguageModel): constrains which characters may follow a beam
        useNGrams (bool): also score beams with unigram/bigram probabilities

    Returns:
        str: text of the most probable beam

    Raises:
        ValueError: if the LM proposes a character missing from lm.getAllChars()
    """
    chars=lm.getAllChars()
    blankIdx=len(chars) # blank label is supposed to be last label in RNN output

    # char -> column, first occurrence wins (same as str.index); built once
    # instead of a linear chars.index() per lookup
    charIdx={}
    for idx, ch in enumerate(chars):
        charIdx.setdefault(ch, idx)

    def labelIndex(c):
        try:
            return charIdx[c]
        except KeyError:
            raise ValueError('character %r is not in lm.getAllChars()' % c) from None

    maxT,_=mat.shape # shape of RNN output: TxC
    last=BeamList() # list of beams at time-step before beginning of RNN output
    last.addBeam(Beam(lm, useNGrams)) # start with genesis beam (empty string)

    # go over all time-steps
    for t in range(maxT):
        curr=BeamList() # list of beams at current time-step

        for beam in last.getBestBeams(beamWidth):
            text=beam.getText()

            # probability that beam ends with non-blank:
            # char at time-step t must also occur at t-1
            prNonBlank=0
            if text!='':
                prNonBlank=beam.getPrNonBlank()*mat[t, labelIndex(text[-1])]

            # probability that beam ends with blank
            prBlank=beam.getPrTotal()*mat[t, blankIdx]

            # same text, new optical score
            curr.addBeam(beam.createChildBeam('', prBlank, prNonBlank))

            # extend current beam with characters according to language model
            for c in beam.getNextChars():
                labelIdx=labelIndex(c)
                if text!='' and text[-1]==c:
                    prNonBlank=mat[t, labelIdx]*beam.getPrBlank() # same chars must be separated by blank
                else:
                    prNonBlank=mat[t, labelIdx]*beam.getPrTotal() # different chars can be neighbours
                curr.addBeam(beam.createChildBeam(c, 0, prNonBlank))

        # move current beams to next time-step
        last=curr

    # return most probable beam
    last.completeBeams(lm)
    return last.getBestBeams(1)[0].getText()

#
# def loadFromCSV(fn):
#     "load matrix from csv file. Last entry in row terminated by semicolon."
#     mat=np.genfromtxt(fn, delimiter=';')[:,:-1]
#     mat=softmax(mat)
#     return mat

def softmax(mat):
    """calc softmax such that labels per time-step form probability distribution

    Numerically stable (the per-row max is subtracted before exp) and
    vectorized over the last axis.

    Args:
        mat (np.array): scores, typically shape (T, C) with dim0=t, dim1=c;
            any shape (..., C) works, including a single 1-D row.

    Returns:
        np.array: float64 array of the same shape whose last axis sums to 1.
    """
    scores=np.asarray(mat, dtype=np.float64)
    e=np.exp(scores-scores.max(axis=-1, keepdims=True))
    return e/e.sum(axis=-1, keepdims=True)
def GreedyDecoder(ctc_matrix, blank_label=0):
    """Greedy CTC decoding: most probable class label at every time-step.

    Blank labels are not collapsed yet (TODO), so ``blank_label`` is
    currently unused.

    Args:
        ctc_matrix (torch.Tensor):
            shape (1, num_classes, output_len)
        blank_label (int): blank label that would be collapsed

    Returns:
        torch.Tensor: class label per time step, shape (output_len,)
    """
    _, best_idx = topk(ctc_matrix, k=1, dim=1)  # best_idx: (1, 1, output_len)
    return best_idx[0, 0]

In [None]:
# Demo: build a language model over the 30 command words and use it (below)
# to decode a fixed sample output matrix with word beam search.
# s: corpus the LM is built from -- the command words, one per line.
s='right\neight\ncat\ntree\nbed\nhappy\ngo\ndog\nno\nwow\nnine\nleft\nstop\nthree\nsheila\none\nbird\nzero\nseven\nup\nmarvin\ntwo\nhouse\ndown\nsix\nyes\non\nfive\noff\nfour'
# l: the characters the decoder may emit, in output-column order (24 chars);
#    wordBeamSearch treats column len(l) -- the one after these -- as blank.
#    NOTE(review): ' ' and '_' occupy the slots where IntegerEncode puts '-'
#    (its blank, id 0) and 'pad' (id 1), and l has no 'x' (from "six"), which
#    IntegerEncode would give the last id -- so the last column is treated as
#    blank here. Confirm the matrix columns really follow this layout.
l=' _rightecabdpyonwlfszvum'
# print(len(l))
# print(len(s))
# k: characters that may appear inside a word. It contains every character
#    of l (including ' ' and '_'), so the LM has no non-word characters here;
#    '\n' is not in k, so the corpus lines split into separate words.
k=' _rightecabdpyonwlfszvum.'
# print(s)
testLM=LanguageModel(s,l,k)

#testMat=np.array([[0.3, 0.1, 0, 0.6], [0.3, 0.1, 0, 0.6]])
#testMat=loadFromCSV('/home/aloui/Desktop/eurecom/semesterproject/Wav2Letter-master/CTCWordBeamSearch/data/bentham/mat_0.csv')	
# Sample network output for one utterance: 25 rows of 16 values, transposed
# by the trailing .T to shape (16, 25) = (time-steps, labels), the T x C
# layout wordBeamSearch expects (blank in the last column).
# All values are <= 0, so they look like log-probabilities (e.g.
# log_softmax output) -- TODO confirm; softmax() is applied before decoding.
testMat=np.array([[-1.7593e+01, -1.7074e+01, -1.6497e+01, -1.5797e+01, -1.4067e+01,
      -1.2747e+01, -1.4210e+01, -1.1112e+01, -1.0233e+01, -1.1682e+01,
      -1.4552e+01, -1.2531e+01, -6.7294e+00, -4.3740e+00, -1.0335e-03,
      -1.4305e-06],
     [-1.0432e+02, -1.0189e+02, -9.9412e+01, -9.6165e+01, -8.6178e+01,
      -7.3011e+01, -6.4005e+01, -5.0625e+01, -4.1224e+01, -3.3498e+01,
      -2.7419e+01, -1.8274e+01, -1.1610e+01, -1.2879e+01, -1.3572e+01,
      -1.6149e+01],
     [-8.5649e+01, -8.3730e+01, -8.1745e+01, -7.8909e+01, -6.8729e+01,
      -5.5506e+01, -4.6494e+01, -3.2475e+01, -2.3045e+01, -2.0536e+01,
      -2.3952e+01, -2.6844e+01, -3.0908e+01, -4.0220e+01, -4.3960e+01,
      -3.9924e+01],
     [-1.3802e+02, -1.3524e+02, -1.3243e+02, -1.2812e+02, -1.1087e+02,
      -8.6395e+01, -6.6041e+01, -3.9905e+01, -1.5520e+01, -3.7718e-03,
      -1.2390e-03, -6.9552e+00, -1.7081e+01, -3.3094e+01, -4.5169e+01,
      -4.8660e+01],
     [-9.4878e+01, -9.2692e+01, -9.0491e+01, -8.7345e+01, -7.5866e+01,
      -6.0044e+01, -4.9246e+01, -3.6678e+01, -2.8755e+01, -2.9582e+01,
      -3.9844e+01, -4.9655e+01, -5.9207e+01, -7.0509e+01, -7.0579e+01,
      -5.8878e+01],
     [-1.2235e+01, -1.1861e+01, -1.1515e+01, -1.0706e+01, -5.7388e+00,
      -1.1462e-01, -6.2583e-05, -1.0385e-02, -8.1400e+00, -1.6062e+01,
      -2.2975e+01, -2.8231e+01, -3.6365e+01, -4.7781e+01, -4.9978e+01,
      -3.6779e+01],
     [-4.0660e+01, -3.9494e+01, -3.8334e+01, -3.6669e+01, -3.0792e+01,
      -2.3874e+01, -2.2574e+01, -1.9128e+01, -2.0773e+01, -2.2111e+01,
      -2.3973e+01, -2.0269e+01, -2.1135e+01, -3.1099e+01, -4.0071e+01,
      -3.9298e+01],
     [-5.4299e+01, -5.2935e+01, -5.1557e+01, -4.9247e+01, -3.8925e+01,
      -2.4997e+01, -1.6878e+01, -4.5741e+00, -3.2789e-04, -5.5843e+00,
      -1.3035e+01, -1.5696e+01, -2.0302e+01, -3.2494e+01, -4.1780e+01,
      -3.7952e+01],
     [-9.9552e+01, -9.7791e+01, -9.6043e+01, -9.3321e+01, -8.2187e+01,
      -6.7042e+01, -5.7514e+01, -4.5464e+01, -3.9456e+01, -3.7595e+01,
      -4.2671e+01, -4.7396e+01, -5.1433e+01, -5.9513e+01, -5.8693e+01,
      -5.0127e+01],
     [-5.9186e+01, -5.8068e+01, -5.6943e+01, -5.5073e+01, -4.6770e+01,
      -3.6241e+01, -3.1918e+01, -2.6652e+01, -2.5285e+01, -2.1906e+01,
      -2.0537e+01, -1.1101e+01, -1.9169e-02, -1.2684e-02, -6.8767e+00,
      -1.3521e+01],
     [-9.8360e+01, -9.6137e+01, -9.3839e+01, -9.0783e+01, -8.1057e+01,
      -6.8911e+01, -6.2822e+01, -5.4668e+01, -5.1669e+01, -5.1261e+01,
      -5.3135e+01, -5.2876e+01, -5.3515e+01, -5.9671e+01, -6.1255e+01,
      -5.5138e+01],
     [-9.6387e+01, -9.4265e+01, -9.2152e+01, -8.9309e+01, -8.0050e+01,
      -6.8218e+01, -6.1441e+01, -5.3107e+01, -5.0363e+01, -4.7982e+01,
      -4.6926e+01, -4.4706e+01, -4.6223e+01, -5.4721e+01, -6.2478e+01,
      -6.3263e+01],
     [-1.2537e+02, -1.2287e+02, -1.2031e+02, -1.1654e+02, -1.0200e+02,
      -8.2215e+01, -6.7355e+01, -4.9493e+01, -3.9044e+01, -3.0222e+01,
      -2.3936e+01, -1.7541e+01, -1.5676e+01, -2.6550e+01, -4.0139e+01,
      -4.8997e+01],
     [-1.0041e+02, -9.8510e+01, -9.6565e+01, -9.3949e+01, -8.5550e+01,
      -7.4019e+01, -6.6600e+01, -5.4862e+01, -4.6085e+01, -3.6859e+01,
      -3.2397e+01, -2.6417e+01, -2.4829e+01, -3.3564e+01, -4.1811e+01,
      -4.3082e+01],
     [-8.7072e+01, -8.5339e+01, -8.3649e+01, -8.1055e+01, -7.0073e+01,
      -5.5333e+01, -4.5278e+01, -3.2008e+01, -2.3462e+01, -1.8506e+01,
      -1.7126e+01, -1.7311e+01, -2.3693e+01, -3.7960e+01, -4.7618e+01,
      -5.0385e+01],
     [-8.7825e+01, -8.5692e+01, -8.3554e+01, -8.0611e+01, -7.0660e+01,
      -5.8025e+01, -5.1754e+01, -4.3738e+01, -3.9813e+01, -3.7932e+01,
      -4.0408e+01, -3.6871e+01, -3.3969e+01, -4.0756e+01, -5.0601e+01,
      -5.4408e+01],
     [-9.3954e+01, -9.1868e+01, -8.9801e+01, -8.6735e+01, -7.4910e+01,
      -5.8734e+01, -4.7331e+01, -3.3974e+01, -2.6063e+01, -2.5609e+01,
      -3.3082e+01, -4.1594e+01, -5.1939e+01, -6.6132e+01, -7.1978e+01,
      -6.4930e+01],
     [-1.0768e+02, -1.0579e+02, -1.0389e+02, -1.0104e+02, -9.0296e+01,
      -7.5215e+01, -6.4175e+01, -4.9190e+01, -3.4810e+01, -1.7655e+01,
      -6.6961e+00, -9.7311e-04, -4.0296e+00, -1.9473e+01, -3.2847e+01,
      -3.7509e+01],
     [-6.4869e+01, -6.3313e+01, -6.1785e+01, -5.9542e+01, -5.0833e+01,
      -4.0101e+01, -3.5039e+01, -2.9649e+01, -2.9994e+01, -3.3351e+01,
      -4.0330e+01, -4.2615e+01, -4.5620e+01, -5.3958e+01, -5.7320e+01,
      -5.2397e+01],
     [-4.8876e-06, -7.1525e-06, -1.0014e-05, -2.2530e-05, -3.2248e-03,
      -2.2230e+00, -9.6915e+00, -1.3734e+01, -1.9076e+01, -2.2994e+01,
      -2.7448e+01, -3.1231e+01, -3.9155e+01, -5.2370e+01, -5.1645e+01,
      -3.4997e+01],
     [-4.8098e+01, -4.6839e+01, -4.5527e+01, -4.3635e+01, -3.6971e+01,
      -2.8894e+01, -2.7736e+01, -2.6435e+01, -3.2290e+01, -4.1844e+01,
      -5.2870e+01, -5.9240e+01, -6.2683e+01, -6.7746e+01, -5.9529e+01,
      -4.6633e+01],
     [-1.7368e+02, -1.6963e+02, -1.6551e+02, -1.5995e+02, -1.4144e+02,
      -1.1520e+02, -9.2422e+01, -6.4522e+01, -3.9035e+01, -2.8712e+01,
      -3.4158e+01, -4.1348e+01, -5.0714e+01, -6.1566e+01, -6.3087e+01,
      -5.7808e+01],
     [-1.3419e+02, -1.3163e+02, -1.2914e+02, -1.2557e+02, -1.1194e+02,
      -9.2902e+01, -7.8236e+01, -5.8599e+01, -4.1839e+01, -3.0906e+01,
      -3.1228e+01, -3.5177e+01, -4.4230e+01, -5.9774e+01, -6.6112e+01,
      -6.3609e+01],
     [-1.1081e+02, -1.0840e+02, -1.0584e+02, -1.0300e+02, -9.7691e+01,
      -9.2273e+01, -9.1687e+01, -8.7705e+01, -8.7443e+01, -8.8239e+01,
      -9.0865e+01, -9.0961e+01, -8.9355e+01, -8.9845e+01, -7.9617e+01,
      -6.5430e+01],
     [-1.4511e+02, -1.4197e+02, -1.3884e+02, -1.3450e+02, -1.1960e+02,
      -9.9066e+01, -8.2916e+01, -6.2148e+01, -4.2919e+01, -2.9997e+01,
      -2.8739e+01, -3.1952e+01, -4.0848e+01, -5.7141e+01, -6.7129e+01,
      -6.4137e+01]]).T
# print('shape of matrix')
# print(testMat.shape)
# beam width: number of best beams expanded per time-step
testBW=25
# softmax makes each time-step a distribution over labels; useNGrams=False
# means the LM only restricts which characters may follow (prefix tree) and
# does not add unigram/bigram scores
res=wordBeamSearch(softmax(testMat), testBW, testLM, False)
print('Result: "'+res+'"')


In [None]:
#model
from __future__ import print_function
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class Wav2Letter(nn.Module):
    """Wav2Letter speech recognition model.

    The architecture is based on Facebook AI Research's Wav2Letter paper
    (https://arxiv.org/pdf/1609.03193.pdf): a stack of 1-D convolutions over
    time that accepts MFCC or power-spectrum features, with each feature
    treated as an input channel.

    TODO: use cuda if available

    Args:
        num_features (int): number of mfcc features (input channels)
        num_classes (int): number of unique grapheme class labels
            (output channels)
    """

    def __init__(self, num_features, num_classes):
        super(Wav2Letter, self).__init__()

        # Conv1d(in_channels, out_channels, kernel_size, stride)
        self.layers = nn.Sequential(
            nn.Conv1d(num_features, 250, 48, 2),
            torch.nn.ReLU(),
            nn.Conv1d(250, 250, 7),
            torch.nn.ReLU(),
            nn.Conv1d(250, 250, 7),
            torch.nn.ReLU(),
            nn.Conv1d(250, 250, 7),
            torch.nn.ReLU(),
            nn.Conv1d(250, 250, 7),
            torch.nn.ReLU(),
            nn.Conv1d(250, 250, 7),
            torch.nn.ReLU(),
            nn.Conv1d(250, 250, 7),
            torch.nn.ReLU(),
            nn.Conv1d(250, 250, 7),
            torch.nn.ReLU(),
            nn.Conv1d(250, 2000, 32),
            torch.nn.ReLU(),
            nn.Conv1d(2000, 2000, 1),
            torch.nn.ReLU(),
            nn.Conv1d(2000, num_classes, 1),
        )

    def forward(self, batch):
        """Run the network and return log probabilities over graphemes.

        Args:
            batch (torch.Tensor): mini batch of data,
                shape (batch_size, num_features, frame_len)

        Returns:
            torch.Tensor: log probabilities,
                shape (batch_size, num_classes, output_len)
        """
        # y_pred shape (batch_size, num_classes, output_len)
        y_pred = self.layers(batch)

        # log softmax over the class (channel) axis
        return F.log_softmax(y_pred, dim=1)

    def fit(self, inputs, output, optimizer, ctc_loss, batch_size, epoch, print_every=50):
        """Train the model with CTC loss on sequential mini batches.

        Samples are consumed in stored order (no shuffling); the last mini
        batch of an epoch may be smaller than ``batch_size``.

        Args:
            inputs (torch.Tensor): shape (sample_size, num_features, frame_len)
            output (torch.Tensor): integer targets, shape (sample_size, seq_len);
                every row is used as a full-length target of seq_len labels
            optimizer (torch.optim.Optimizer): pytorch optimizer
            ctc_loss (nn.CTCLoss): ctc loss function
            batch_size (int): size of mini batches
            epoch (int): number of epochs
            print_every (int): print the loss every this many steps

        Raises:
            ValueError: if ``batch_size`` is not positive or ``inputs`` is empty
                (the average epoch loss would otherwise divide by zero).
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got %r" % (batch_size,))
        if len(inputs) == 0:
            raise ValueError("inputs must contain at least one sample")

        # make sure we train in training mode even if eval() was called before
        self.train()

        total_steps = math.ceil(len(inputs) / batch_size)
        # every target row has the same (padded) length
        seq_length = output.shape[1]

        for t in range(epoch):
            samples_processed = 0
            epoch_loss = 0.0

            for step in range(total_steps):
                optimizer.zero_grad()
                batch = inputs[samples_processed:samples_processed + batch_size]
                mini_batch_size = len(batch)
                targets = output[samples_processed:samples_processed + mini_batch_size]

                # log_probs shape (batch_size, num_classes, output_len)
                log_probs = self.forward(batch)

                # CTCLoss expects shape (input_length, batch_size, num_classes)
                log_probs = log_probs.permute(2, 0, 1)

                # CTC arguments
                # https://pytorch.org/docs/master/nn.html#torch.nn.CTCLoss
                # every prediction uses the full network output length
                input_lengths = torch.full((mini_batch_size,), log_probs.shape[0], dtype=torch.long)
                target_lengths = torch.full((mini_batch_size,), seq_length, dtype=torch.long)

                loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
                epoch_loss += loss.item()

                loss.backward()
                optimizer.step()

                samples_processed += mini_batch_size

                if step % print_every == 0:
                    print("epoch", t + 1, ":", "step", step + 1, "/", total_steps, ", loss ", loss.item())

            print("epoch", t + 1, "average epoch loss", epoch_loss / total_steps)

    def eval(self, sample=None):
        """Switch to evaluation mode, or evaluate a single sample.

        ``nn.Module.eval()`` (called with no argument) is the standard PyTorch
        method that puts the module in evaluation mode and returns ``self``;
        that behavior is preserved when ``sample`` is omitted.

        Args:
            sample (torch.Tensor, optional): shape (num_features, frame_len)

        Returns:
            Wav2Letter: ``self``, when ``sample`` is None.
            torch.Tensor: log probabilities of shape
                (1, num_classes, output_len), otherwise.
        """
        if sample is None:
            return super(Wav2Letter, self).eval()
        # add a batch dimension of size 1
        _input = sample.reshape(1, sample.shape[0], sample.shape[1])
        # inference only: no autograd graph is needed
        with torch.no_grad():
            return self.forward(_input)

## Load Data

In [None]:

# using google's speech command dataset
# Load the feature vectors and integer-encoded targets via
# GoogleSpeechCommand.load_vectors (defined elsewhere).
# NOTE(review): the on-disk format is whatever save_vectors wrote -- confirm.
gs = GoogleSpeechCommand()
_inputs, _targets = gs.load_vectors("../input/comand-dat/")

In [None]:
# if __name__ == "__main__":
#     gs = GoogleSpeechCommand()
#     inputs, targets = gs.get_data()
#     gs.save_vectors("./speech_data", inputs, targets)
#     gs.intencode.save("./speech_data")
#     print("preprocessed and saved")


In [None]:
# number of MFCC coefficients per frame; used as the model's input channels
mfcc_features = 13
# number of output classes. NOTE(review): hard-coded -- it should equal
# gs.intencode.grapheme_count (blank "-", "pad", plus every distinct label
# character); confirm after loading the data.
grapheme_count = 25
grapheme_count

In [None]:
#del _inputs
#del _targets

In [None]:
import torch
# features as a default-dtype float tensor
inputs = torch.Tensor(_inputs)
# integer-encoded, padded targets as an int32 tensor
targets = torch.IntTensor(_targets)

In [None]:
# sanity-check the loaded tensor shapes
print(inputs.shape)
print(targets.shape)

## Build Model

In [None]:
import torch.nn as nn
import torch.optim as optim

model = Wav2Letter(mfcc_features, grapheme_count)
print(model.layers)

# nn.CTCLoss uses blank=0 by default, matching the "-" label reserved at
# index 0 by IntegerEncode
ctc_loss = nn.CTCLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

## Train

In [None]:
# Each mfcc feature is a channel
# https://pytorch.org/docs/stable/nn.html#torch.nn.Conv1d
# transpose (sample_size, in_frame_len, mfcc_features)
# to      (sample_size, mfcc_features, in_frame_len)
# Wav2Letter.fit / Wav2Letter.eval expect this channels-first layout.
inputs = inputs.transpose(1, 2)
print(inputs.shape)

In [None]:
# do short training run
# mini batches are taken in stored order; 10 passes over the data
batch_size = 1000
model.fit(inputs, targets, optimizer, ctc_loss, batch_size, epoch=10)

## Evaluate

In [None]:
# display the integer -> character mapping used to decode predictions
gs.intencode.index2char

In [None]:

# pick one (already transposed) sample and its padded target for inspection
sample = inputs[6]
sample_target = targets[6]

print(sample.shape)

In [None]:
# save only the learned parameters (state_dict), not the model object
torch.save(model.state_dict(),'wav2letter.pt')


In [None]:
# display the first raw (pre-tensor) feature vector
_inputs[0]

In [None]:
#model = TheModelClass(*args, **kwargs)
# Restore trained weights into the existing model instance.
# NOTE(review): this reads a Kaggle input path, not the 'wav2letter.pt'
# written by the torch.save cell -- confirm which checkpoint is intended.
model.load_state_dict(torch.load('../input/kernelc6ea63747a/wav2letter.pt'))
# Wav2Letter.eval(sample) adds a batch dimension and returns log probabilities
log_prob=model.eval(inputs[3])
# GreedyDecoder is defined elsewhere in the notebook
output = GreedyDecoder(log_prob)
print(output)

In [None]:
# log probabilities for sample 1, shape (1, num_classes, output_len)
log_prob =model.eval(inputs[1])

output = GreedyDecoder(log_prob)
print(output)

In [None]:
# Map greedy-decoded class indices back to characters, skipping the
# blank (0) and pad (1) classes.
mo = ''.join(
    gs.intencode.index2char[j]
    for j in output.numpy()
    if j not in (0, 1)
)
print(mo)

In [None]:
# Rebuild the ground-truth word for sample 2 by dropping pad (1) entries.
mo = ''.join(gs.intencode.index2char[j] for j in _targets[2] if j != 1)
print(mo)

In [None]:
# Corpus for the word language model: the 30 command words, one per line.
s='right\neight\ncat\ntree\nbed\nhappy\ngo\ndog\nno\nwow\nnine\nleft\nstop\nthree\nsheila\none\nbird\nzero\nseven\nup\nmarvin\ntwo\nhouse\ndown\nsix\nyes\non\nfive\noff\nfour'
# Character set handed to LanguageModel (defined elsewhere).
# NOTE(review): 'x' (needed for "six" in the corpus) is missing here, and this
# string has 24 characters while the model emits grapheme_count = 25 classes
# (blank "-" and "pad" at indices 0 and 1). Confirm how LanguageModel /
# wordBeamSearch map matrix columns to these characters.
l=' _rightecabdpyonwlfszvum'
# print(len(l))
# print(len(s))
# word characters for the language model; identical to l here
k=' _rightecabdpyonwlfszvum'
# print(s)
testLM=LanguageModel(s,l,k)
# log_prob[0] drops the batch axis: (num_classes, output_len); .T gives
# (output_len, num_classes)
testMat=log_prob[0].detach().numpy().T
testBW=25
# softmax is defined elsewhere; presumably applied over the class axis to turn
# the log probabilities back into probabilities -- confirm.
res=wordBeamSearch(softmax(testMat), testBW, testLM, False)
print('Result: "'+res+'"')


In [None]:
 _, audio = read('../input/bird-audio/0a7c2a8d_nohash_0.wav')
mfccs = mfcc_spec(
    audio, 16000, window_stride=(160, 80),
    fft_size=512, num_filt=20, num_coeffs=13
)
mfccs = normalize(mfccs)
diff = 225 - mfccs.shape[0]
mfccs = np.pad(mfccs, ((0, diff), (0, 0)), "constant")
print(mfccs.shape)

**Blank labels are 0, Pads are 1**

**As you can see, if you remove the 0's (blanks) and the 1's (pads) from the output, the model predicts the correct labels!**