In [None]:
import datetime
from collections import deque
from glob import glob

import keras.backend as K
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from utils.data_loader import data_loader
from model import get_cnn_lstm_t_dense

Get the data paths and create the config. Note: you have to download and preprocess the data before the following code will run.

In [None]:
# Preprocessed data location. The same train/val layout also exists under
# 'prep-train-clean/' (presumably the unshuffled original) -- switch these two
# paths to use it instead.
train_folder = 'prep-train-clean-copy-shuffled/train/'
val_folder = 'prep-train-clean-copy-shuffled/val/'

# Number of training / validation samples; used below to derive the number of
# steps per epoch. NOTE(review): hard-coded -- confirm they match the data on disk.
Nt = 247000
Nv = 59500

# Load the first training chunk as a quick check that the data is in place.
x = np.load(train_folder + '0.npy')
input_length = np.load(train_folder + '0_len.npy')

y = np.load(train_folder + '0_target.npy')
label_length = np.load(train_folder + '0_target_len.npy')

In [None]:
# Training config: hyper-parameters and model dimensions.
config = dict(
    lr=0.001,
    patience=6,
    epochs=100,
    batch_size=64,
    hidden_dim=384,
    t_step=480,
    signal_dim=160,
    lstm_layers=1,
    print_every=100,
    num_classes_internal=32,
    num_classes=29,
    chunk_size=1024,
    f_size=32,
)

# Derived: number of batches drawn per epoch.
# NOTE(review): validation steps divide by 'f_size' rather than 'batch_size' --
# confirm this matches the batch size data_loader uses in mode='val'.
config.update(
    steps_per_epoch_train=int(Nt / config['batch_size']),
    steps_per_epoch_val=int(Nv / config['f_size']),
)

Initialize model

In [None]:
# Build the model inside a fresh default graph, so that re-running this cell
# does not add a second copy of the ops/variables to an existing graph.
tf.reset_default_graph()
model = get_cnn_lstm_t_dense(config)
model.initialize()

In [None]:
# Number of trainable parameters (for comparison with the Keras model):
# the product of each variable's static dimensions, summed over variables.
total_parameters = sum(
    int(np.prod([dim.value for dim in variable.get_shape()]))
    for variable in tf.trainable_variables()
)
print(total_parameters)

In [None]:
###### not needed for training
# save model for huawei compatibility check
#####################################
# Freeze the freshly initialised graph down to the 'out' node. The session is
# used only for this and is closed afterwards; the training cell opens its own.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        tf.get_default_graph().as_graph_def(),
        ['out']
    )

# save graph
with tf.gfile.GFile('best_model.pb', "wb") as f:
    f.write(output_graph_def.SerializeToString())
#####################################

In [None]:
# Display the model's graph endpoints.
(model.inputs, model.outputs, model.targets, model.save_nodes)

Define the CTC loss and train the model

In [None]:
def ctc_lambda_func(args):
    """Return the per-sample CTC loss.

    ``args`` is a 4-sequence ``(y_pred, labels, input_length, label_length)``.
    It is forwarded to ``K.ctc_batch_cost``, which expects the labels first.

    NOTE(review): the Keras docs describe ``input_length``/``label_length`` as
    ``(samples, 1)`` tensors, while the training cell feeds ``(None,)``
    placeholders -- confirm this works with the installed Keras version.
    """
    predictions, true_labels, pred_lengths, true_lengths = args
    return K.ctc_batch_cost(true_labels, predictions, pred_lengths, true_lengths)

In [None]:
# Batch iterators over the preprocessed training and validation folders;
# each call to next() yields (x, x_len, y, y_len).
data_loader_train = data_loader(train_folder, config, mode='train')
data_loader_val = data_loader(val_folder, config, mode='val')

In [None]:
############# Training ##########
#################################

# Placeholders used only for training. The exported graph is pruned to the
# 'out' node, so they are not part of the production .pb.
the_labels = tf.placeholder(shape=(None, None), dtype=tf.float32)
input_lengths = tf.placeholder(shape=(None,), dtype=tf.int64)
label_lengths = tf.placeholder(shape=(None,), dtype=tf.int64)

# Fed on every training step so the learning rate can change between epochs.
learning_rate = tf.placeholder(tf.float32, shape=[])

# Mean CTC loss over the batch.
ctc_cost = ctc_lambda_func(
    [model.outputs[0], the_labels, input_lengths, label_lengths]
)
ctc_cost = tf.reduce_mean(ctc_cost)

# optimizer
train_ops = tf.train.AdamOptimizer(learning_rate).minimize(ctc_cost)

# initialize session
sess = tf.Session()
sess.run(tf.global_variables_initializer())


def _make_feed(batch, lr=None):
    """Build the feed dict for one (x, x_len, y, y_len) batch.

    `lr` is only needed for training steps; validation omits it.
    """
    batch_x, batch_x_len, batch_y, batch_y_len = batch
    feed = {
        model.inputs[0]: batch_x,
        # zero initial state, one row per sample
        model.inputs[1]: np.zeros((batch_x.shape[0], config['hidden_dim'])),
        the_labels: batch_y,
        input_lengths: batch_x_len,
        label_lengths: batch_y_len,
    }
    if lr is not None:
        feed[learning_rate] = lr
    return feed


overall_train_loss, overall_val_loss = [], []
run_id = str(datetime.datetime.now())
# The best model of this run goes to runs/<run_id>/best_model.pb -- the layout
# the "Test best model" cell loads from.
run_dir = 'runs/' + run_id + '/'
tf.gfile.MakeDirs(run_dir)

best_val_loss = np.inf
# Local copy, so config['lr'] is not decayed in place (re-running this cell
# would otherwise start from an already reduced learning rate).
current_lr = config['lr']
# NOTE(review): config['patience'] is not used by this loop.
for i in range(config['epochs']):
    print('Epoch ' + str(i+1))

    # Training
    train_loss = []
    for step in tqdm(range(config['steps_per_epoch_train'])):
        feed = _make_feed(next(data_loader_train), current_lr)
        train_loss.append(sess.run([ctc_cost, train_ops], feed)[0])

    # Validation
    val_loss = []
    for step in tqdm(range(config['steps_per_epoch_val'])):
        val_loss.append(sess.run(ctc_cost, _make_feed(next(data_loader_val))))

    t_loss = np.mean(train_loss)
    v_loss = np.mean(val_loss)
    print('Epoch ' + str(i+1) + ' | Training loss: ' + str(t_loss) + ' | Validation loss: ' + str(v_loss))
    overall_train_loss.append(t_loss)
    overall_val_loss.append(v_loss)

    if v_loss < best_val_loss:
        best_val_loss = v_loss

        # freeze graph (variables -> constants, pruned to 'out') for writing to .pb
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            ['out']
            )

        # save graph
        with tf.gfile.GFile(run_dir + 'best_model.pb', "wb") as f:
            f.write(output_graph_def.SerializeToString())
    else:
        # reduce learning rate on plateau
        current_lr = 0.5 * current_lr
        print('new lr: ' + str(current_lr))

### Test best model

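Note: the inference cells below call `softmax`, `int_sequence_to_text` and `wordBeamSearch` and use `lm`. These are defined in the helper-function cell at the end of this notebook, so run that cell first.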

In [None]:
from tensorflow.python.platform import gfile

tf.reset_default_graph()

# Load the frozen best graph into a fresh default graph.
sess = tf.Session()
# Path of the frozen model to evaluate -- insert your own run here.
model_filename = 'runs/2019-01-14 00:26:20.816162/best_model.pb'
with tf.gfile.GFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# Imported ops get the default 'import/' name prefix (used in the next cell).
tf.import_graph_def(graph_def)

In [None]:
# Endpoints of the imported graph. 'Placeholder' / 'Placeholder_1' are TF's
# auto-generated names for unnamed placeholders -- presumably the model's two
# inputs (model.inputs[0] / [1] during training); verify if the model changes.
# The frozen graph holds only constants, so no variable initialisation is needed.
in1 = sess.graph.get_tensor_by_name("import/Placeholder:0")
in2 = sess.graph.get_tensor_by_name("import/Placeholder_1:0")
out = sess.graph.get_tensor_by_name("import/out:0")

In [None]:
# Draw one batch from the training set (True) or the validation set (False).
test_train = False

batch_x, batch_x_len, batch_y, batch_y_len = next(
    data_loader_train if test_train else data_loader_val
)

In [None]:
# Decode one sample of the batch loaded above with the frozen model.
sample_n = 0  # which of the evaluated samples to show
nb = 1        # how many samples of the batch to run through the network
assert 0 <= sample_n < nb, 'sample_n must index one of the nb evaluated samples'

result = sess.run(out, {
    in1: batch_x[:nb],
    # Zero initial state. Its width must match the model's hidden size; the
    # training cell feeds config['hidden_dim'] for this input.
    in2: np.zeros((nb, config['hidden_dim'])),
})

# Reshape 'out' to (samples, time steps, internal classes) and keep only the
# first num_classes columns.
result = np.reshape(
    result, (nb, config['t_step'], config['num_classes_internal'])
)[:, :, :config['num_classes']]
sub_results_soft = softmax(result)

# CTC decoding of one sample. ctc_decode's input_length is the number of time
# steps of each prediction -- not the length of the true transcription.
pred_sample = sub_results_soft[sample_n:sample_n + 1]
decoded = K.ctc_decode(pred_sample, [pred_sample.shape[1]])[0][0]
# decoded labels are 0-based while index_map is 1-based
pred_ints = (K.eval(decoded) + 1).flatten().tolist()

print('True transcription:\n' + '\n' + int_sequence_to_text(batch_y[sample_n] + 1))
print('-' * 80)
print('Predicted transcription:\n' + '\n' + int_sequence_to_text(pred_ints))
print('-' * 80)
# word beam search constrained by the corpus LM (beam width 25, no n-gram scoring)
print('Predicted transcription with LM:\n' + '\n' + wordBeamSearch(sub_results_soft[sample_n], 25, lm, False))

Helper functions for inference

In [None]:
# Util functions

# Output character inventory (the CTC blank is not included).
char_map_str = "' abcdefghijklmnopqrstuvwxyz"

# char_map: character -> 0-based label index
char_map = {letter: index for index, letter in enumerate(char_map_str)}
# index_map: 1-based label index -> character. Decoded 0-based labels are
# shifted by +1 before lookup.
index_map = {index + 1: letter for index, letter in enumerate(char_map_str)}

def text_to_int_sequence(text):
    """Map every character of ``text`` to its 0-based label via ``char_map``.

    Raises KeyError on the first character that ``char_map`` does not contain.
    """
    return list(map(char_map.__getitem__, text))

def get_chunk_sizes(N, chunk_size):
    """Split ``N`` items into consecutive chunk sizes.

    Returns as many ``chunk_size`` entries as fit into ``N``, followed by
    one entry for the non-zero remainder, if there is one.
    """
    n_full = int(N / chunk_size)
    remainder = N - n_full * chunk_size
    sizes = [chunk_size] * n_full
    if remainder:
        sizes.append(remainder)
    return sizes

def softmax(x, axis=-1):
    """Compute softmax values of ``x`` along ``axis`` (default: last axis).

    The per-slice maximum is subtracted before exponentiating. This leaves the
    result unchanged (softmax is shift-invariant) but prevents large scores from
    overflowing to inf/nan. For the 3-D inputs used in this notebook, the default
    axis is the same as the previously hard-coded axis 2.
    """
    x = np.asarray(x)
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

# char_map / index_map are built once, at the top of this cell.

def int_sequence_to_text(int_sequence):
    """Convert 1-based label indices back to text via ``index_map``.

    Index 29 is skipped (presumably the CTC blank after the +1 shift); any
    other index missing from ``index_map`` raises KeyError.
    """
    return ''.join(index_map[i] for i in int_sequence if i != 29)


def cnn_output_length(input_length, filter_size, border_mode, stride, dilation=1):
    """Length of a 1-D convolution's output.

    Args:
        input_length: length of the input sequence, or None if unknown.
        filter_size: kernel width.
        border_mode: 'same' (output as long as input before striding) or
            'valid' (no padding).
        stride: convolution stride.
        dilation: dilation rate; widens the effective kernel.

    Returns:
        The output length (ceil of the pre-stride length / stride), or None
        when ``input_length`` is None.

    Raises:
        ValueError: for any other ``border_mode``.
    """
    if input_length is None:
        return None
    dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
    if border_mode == 'same':
        output_length = input_length
    elif border_mode == 'valid':
        output_length = input_length - dilated_filter_size + 1
    else:
        raise ValueError("border_mode must be 'same' or 'valid', got %r" % (border_mode,))
    # integer ceil division
    return (output_length + stride - 1) // stride

class Node:
    """One node of a character prefix tree."""

    def __init__(self):
        # next character -> child Node
        self.children = {}
        # True when the path from the root to this node spells a whole word
        self.isWord = False

    def __str__(self):
        childChars = ''.join(self.children)
        return 'isWord: ' + str(self.isWord) + '; children: ' + childChars


class PrefixTree:
    """Character prefix tree (trie) for looking up words and word prefixes."""

    def __init__(self):
        self.root = Node()

    def addWord(self, text):
        """Insert `text` and mark its last node as a complete word.

        An empty string adds nothing.
        """
        node = self.root
        for c in text:
            if c not in node.children:
                node.children[c] = Node()
            node = node.children[c]
        if text:
            node.isWord = True

    def addWords(self, words):
        """Insert every word of the iterable `words`."""
        for w in words:
            self.addWord(w)

    def getNode(self, text):
        """Return the node reached by spelling `text`, or None if no stored word has this prefix."""
        node = self.root
        for c in text:
            node = node.children.get(c)
            if node is None:
                return None
        return node

    def isWord(self, text):
        """True if `text` was added as a complete word."""
        node = self.getNode(text)
        return node.isWord if node is not None else False

    def getNextChars(self, text):
        """All characters that can directly follow the prefix `text` (empty if unknown)."""
        node = self.getNode(text)
        return list(node.children) if node is not None else []

    def getNextWords(self, text):
        """All words having `text` as a prefix, including `text` itself if it is a word.

        Words come out in breadth-first order, so shorter words come first.
        """
        node = self.getNode(text)
        if node is None:
            return []
        words = []
        # deque gives O(1) pops from the front of the BFS queue
        queue = deque([(node, text)])
        while queue:
            current, prefix = queue.popleft()
            for c, child in current.children.items():
                queue.append((child, prefix + c))
            if current.isWord:
                words.append(prefix)
        return words

    def dump(self):
        """Print every node of the tree in breadth-first order."""
        queue = deque([self.root])
        while queue:
            current = queue.popleft()
            queue.extend(current.children.values())
            print(current)
            
import re

class LanguageModel:
    """Word unigram/bigram language model with add-k smoothing.

    Built from a raw text corpus. Besides word probabilities, it keeps a prefix
    tree of the corpus words, so a decoder can ask which characters or words
    may follow a partially spelled word.
    """

    def __init__(self, corpus, chars, wordChars):
        """Build the model.

        corpus: the training text itself (a string, not a file name).
        chars: all characters the recogniser can emit.
        wordChars: the characters that form words. Every other character of
            `chars` is treated as a word separator. The string is inserted
            verbatim into a regex character class.
        """
        self.wordCharPattern = '[' + wordChars + ']'
        self.wordPattern = self.wordCharPattern + '+'
        words = re.findall(self.wordPattern, corpus)
        self.numWords = len(words)
        self.numUniqueWords = len(set(words))
        self.smoothing = True
        self.addK = 1.0 if self.smoothing else 0.0

        # unigrams: relative frequency of each lower-cased word
        self.unigrams = {}
        for w in words:
            w = w.lower()
            self.unigrams[w] = self.unigrams.get(w, 0) + 1 / self.numWords

        # bigrams: add-k smoothed pair counts, then normalised per first word
        bigrams = {}
        for w1, w2 in zip(words, words[1:]):
            row = bigrams.setdefault(w1.lower(), {})
            w2 = w2.lower()
            row[w2] = row.get(w2, self.addK) + 1
        for row in bigrams.values():
            probSum = self.numUniqueWords * self.addK  # add-k mass for the vocabulary
            for count in row.values():
                probSum += count
            for w2 in row:
                row[w2] /= probSum
        self.bigrams = bigrams

        # prefix tree over all corpus words
        self.tree = PrefixTree()
        self.tree.addWords(words)

        # all chars, word chars and non-word (separator) chars
        self.allChars = chars
        self.wordChars = wordChars
        self.nonWordChars = self._nonWordCharsOf(chars, self.wordCharPattern)

    @staticmethod
    def _nonWordCharsOf(chars, wordCharPattern):
        """Characters of `chars` not matched by `wordCharPattern`.

        The result is deduplicated and keeps the order of first appearance.
        A plain set difference would make the order -- and therefore the
        decoder's beam order -- vary between interpreter runs.
        """
        wordCharSet = set(re.findall(wordCharPattern, chars))
        return ''.join(dict.fromkeys(c for c in chars if c not in wordCharSet))

    def getNextWords(self, text):
        """All corpus words having `text` as a prefix (`text` must be a word prefix)."""
        return self.tree.getNextWords(text)

    def getNextChars(self, text):
        """Characters that may follow the word prefix `text`.

        Separators (non-word chars) are also allowed when `text` is empty or is
        already a complete word.
        """
        nextChars = ''.join(self.tree.getNextChars(text))
        if text == '' or self.isWord(text):
            nextChars += self.getNonWordChars()
        return nextChars

    def getWordChars(self):
        return self.wordChars

    def getNonWordChars(self):
        return self.nonWordChars

    def getAllChars(self):
        return self.allChars

    def isWord(self, text):
        return self.tree.isWord(text)

    def getUnigramProb(self, w):
        """Relative corpus frequency of word `w` (case-insensitive); 0 if unseen."""
        val = self.unigrams.get(w.lower())
        return val if val is not None else 0

    def getBigramProb(self, w1, w2):
        """Probability of `w2` directly following `w1` (case-insensitive).

        Returns 0 if `w1` never starts a bigram.
        """
        row = self.bigrams.get(w1.lower())
        if row is None:
            return 0
        val = row.get(w2.lower())
        if val is not None:
            return val
        # Unseen pair after a known first word.
        # NOTE(review): the denominator mixes w1's unigram *probability* with a
        # count-like term -- confirm this back-off is intended.
        return self.addK / (self.getUnigramProb(w1) * self.numUniqueWords + self.numUniqueWords)
    
import copy


class Optical:
    """Optical (CTC) score of a beam, split by how the beam's paths end."""

    def __init__(self, prBlank=0, prNonBlank=0):
        # probability of the paths ending in a blank / in a non-blank label
        self.prBlank, self.prNonBlank = prBlank, prNonBlank


class Textual:
    """Language-model (textual) state and score of a beam."""

    def __init__(self, text=''):
        self.text = text            # decoded text so far
        self.wordHist = []          # completed words, oldest first
        self.wordDev = ''           # word currently being spelled
        self.prUnnormalized = 1.0   # running product of LM probabilities of completed words
        self.prTotal = 1.0          # LM score used when ranking beams


class Beam:
    """One hypothesis of word beam search: text plus optical and textual scores."""

    def __init__(self, lm, useNGrams):
        """Create the empty ("genesis") beam.

        lm: language model that constrains and scores the text.
        useNGrams: if True, words are scored with the LM's unigram/bigram
            probabilities; otherwise only the developing word is tracked.
        """
        self.optical = Optical(1.0, 0.0)
        self.textual = Textual('')
        self.lm = lm
        self.useNGrams = useNGrams

    def mergeBeam(self, beam):
        """Add the optical probabilities of `beam`, which must have the same text."""
        if self.getText() != beam.getText():
            raise Exception('mergeBeam: texts differ')
        self.optical.prNonBlank += beam.getPrNonBlank()
        self.optical.prBlank += beam.getPrBlank()

    def getText(self):
        return self.textual.text

    def getPrBlank(self):
        return self.optical.prBlank

    def getPrNonBlank(self):
        return self.optical.prNonBlank

    def getPrTotal(self):
        return self.getPrBlank() + self.getPrNonBlank()

    def getPrTextual(self):
        return self.textual.prTotal

    def getNextChars(self):
        """Characters the LM allows after the word currently being spelled."""
        return self.lm.getNextChars(self.textual.wordDev)

    def createChildBeam(self, newChar, prBlank, prNonBlank):
        """Return a new beam that extends this one by `newChar` ('' = no extension).

        The child's textual state is a deep copy of this beam's, updated for
        `newChar`. Its optical score is set to (prBlank, prNonBlank).
        """
        child = Beam(self.lm, self.useNGrams)
        child.textual = copy.deepcopy(self.textual)
        child.textual.text += newChar

        if newChar != '':
            if not self.useNGrams:
                # no n-gram scoring: only keep the developing word up to date
                if newChar in child.lm.getWordChars():
                    child.textual.wordDev += newChar
                else:
                    child.textual.wordDev = ''
            elif newChar in child.lm.getWordChars():
                child._scoreWordChar(newChar)
            elif child.textual.wordDev != '':
                child._scoreCompletedWord()

        child.optical.prBlank = prBlank
        child.optical.prNonBlank = prNonBlank
        return child

    def _scoreWordChar(self, newChar):
        """Append a word character to the developing word and rescore (n-gram mode)."""
        textual = self.textual
        textual.wordDev += newChar
        nextWords = self.lm.getNextWords(textual.wordDev)

        # LM mass of every word that could complete the current prefix: unigrams
        # before the first completed word, otherwise bigrams after the last one
        numWords = len(textual.wordHist)
        prSum = 0
        if numWords == 0:
            for w in nextWords:
                prSum += self.lm.getUnigramProb(w)
        else:
            lastWord = textual.wordHist[-1]
            for w in nextWords:
                prSum += self.lm.getBigramProb(lastWord, w)

        prTotal = textual.prUnnormalized * prSum
        # (numWords + 1)-th root, i.e. geometric mean over the scored words
        textual.prTotal = prTotal ** (1 / (numWords + 1)) if numWords >= 1 else prTotal

    def _scoreCompletedWord(self):
        """Move the finished developing word into the history and rescore (n-gram mode)."""
        textual = self.textual
        textual.wordHist.append(textual.wordDev)
        textual.wordDev = ''

        hist = textual.wordHist
        numWords = len(hist)
        if numWords == 1:
            textual.prUnnormalized *= self.lm.getUnigramProb(hist[-1])
            textual.prTotal = textual.prUnnormalized
        elif numWords >= 2:
            textual.prUnnormalized *= self.lm.getBigramProb(hist[-2], hist[-1])
            textual.prTotal = textual.prUnnormalized ** (1 / numWords)

    def __str__(self):
        return ';'.join([
            '"' + self.getText() + '"',
            str(self.getPrTotal()),
            str(self.getPrTextual()),
            str(self.textual.prUnnormalized),
        ])


class BeamList:
    """All beams alive at one time step, keyed (and merged) by their text."""

    def __init__(self):
        self.beams = {}  # text -> Beam

    def addBeam(self, beam):
        """Insert `beam`, or merge it into the existing beam with the same text."""
        text = beam.getText()
        if text in self.beams:
            self.beams[text].mergeBeam(beam)
        else:
            self.beams[text] = beam

    def getBestBeams(self, num):
        """Return at most `num` beams (the beam width), best first.

        Beams are ranked by optical probability * textual score ** lmWeight.
        """
        lmWeight = 1
        return sorted(self.beams.values(), reverse=True,
                      key=lambda b: b.getPrTotal() * (b.getPrTextual() ** lmWeight))[:num]

    def deletePartialBeams(self, lm):
        """Delete beams whose last (developing) word is not a complete LM word."""
        # Collect keys first: deleting from a dict while iterating over it
        # raises RuntimeError in Python 3.
        partial = [text for text, beam in self.beams.items()
                   if beam.textual.wordDev != '' and not lm.isWord(beam.textual.wordDev)]
        for text in partial:
            del self.beams[text]

    def completeBeams(self, lm):
        """Complete a beam's unfinished last word when the LM offers exactly one completion."""
        for beam in self.beams.values():
            lastPrefix = beam.textual.wordDev
            if lastPrefix == '' or lm.isWord(lastPrefix):
                continue
            words = lm.getNextWords(lastPrefix)
            if len(words) == 1:
                # append the part of the unique candidate word after the prefix
                beam.textual.text += words[0][len(lastPrefix):]

    def dump(self):
        """Print every beam, replacing non-ASCII characters with '?'."""
        for beam in self.beams.values():
            print(str(beam).encode('ascii', 'replace').decode('ascii'))

def wordBeamSearch(mat, beamWidth, lm, useNGrams):
    """Decode a CTC output matrix with word beam search.

    mat: (T, C) matrix of per-time-step label probabilities. The columns before
        ``len(lm.getAllChars())`` follow the order of ``lm.getAllChars()``; that
        index itself is the CTC blank.
    beamWidth: number of best beams carried from one time step to the next.
    lm: language model restricting the text to its vocabulary.
    useNGrams: whether beams are also scored with unigram/bigram probabilities.

    Returns the text of the most probable beam.
    """
    chars = lm.getAllChars()
    blankIdx = len(chars)  # blank label is the last column of the RNN output
    maxT, _ = mat.shape

    last = BeamList()
    last.addBeam(Beam(lm, useNGrams))  # start from the empty beam

    for t in range(maxT):
        curr = BeamList()
        for beam in last.getBestBeams(beamWidth):
            text = beam.getText()
            lastChar = text[-1] if text != '' else None

            # keep the same text: repeat the last label, or emit a blank
            prNonBlank = 0
            if lastChar is not None:
                prNonBlank = beam.getPrNonBlank() * mat[t, chars.index(lastChar)]
            prBlank = beam.getPrTotal() * mat[t, blankIdx]
            curr.addBeam(beam.createChildBeam('', prBlank, prNonBlank))

            # extend the text by every character the language model permits
            for c in beam.getNextChars():
                prChar = mat[t, chars.index(c)]
                # a repeated character must be separated from its predecessor by a blank
                prPrev = beam.getPrBlank() if c == lastChar else beam.getPrTotal()
                curr.addBeam(beam.createChildBeam(c, 0, prChar * prPrev))

        last = curr

    last.completeBeams(lm)
    return last.getBestBeams(1)[0].getText()

# Word list used to build the language model for word beam search.
# The context manager closes the file (the bare open() left it open).
with open('word_list/cleaned_words_10k_nofreq.txt', 'r') as corpus_file:
    corpus = corpus_file.read()

# The LM characters must follow the network's label order, i.e. char_map_str
# ("' abcdefghijklmnopqrstuvwxyz"); words are made of the lowercase letters.
lm = LanguageModel(corpus, char_map_str, "abcdefghijklmnopqrstuvwxyz")