Skip to content

Commit

Permalink
Updates model outputs to make ctcdecode stop at the right timestep
Browse files: browse the repository at this point in the history
  • Loading branch information
ankit committed Jan 23, 2018
1 parent ae28c4f commit e81aad0
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
3 changes: 3 additions & 0 deletions decoder.py
Expand Up @@ -105,6 +105,8 @@ def convert_to_strings(self, out, seq_len):
for p, utt in enumerate(batch):
size = seq_len[b][p]
if size > 0:
# print(utt.size())
# print(utt[0:size])
transcript = ''.join(map(lambda x: self.int_to_char[x], utt[0:size]))
else:
transcript = ''
Expand Down Expand Up @@ -137,6 +139,7 @@ def decode(self, probs, sizes=None):
"""
probs = probs.cpu().transpose(0, 1).contiguous()
out, scores, offsets, seq_lens = self._decoder.decode(probs)
#print(seq_lens)

strings = self.convert_to_strings(out, seq_lens)
offsets = self.convert_tensor(offsets, seq_lens)
Expand Down
17 changes: 15 additions & 2 deletions test.py
@@ -1,6 +1,7 @@
import argparse

import numpy as np
import torch
from torch.autograd import Variable
from tqdm import tqdm

Expand Down Expand Up @@ -63,8 +64,10 @@
output_data = []
for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)):
inputs, targets, input_percentages, target_sizes = data
#print(input_percentages)

inputs = Variable(inputs, volatile=True)
with torch.no_grad():
inputs = Variable(inputs)

# unflatten targets
split_targets = []
Expand All @@ -73,14 +76,23 @@
split_targets.append(targets[offset:offset + size])
offset += size

#print(inputs.size())
if args.cuda:
inputs = inputs.cuda()


out = model(inputs)
out = out.transpose(0, 1) # TxNxH
#print(out.size())
seq_length = out.size(0)
sizes = input_percentages.mul_(int(seq_length)).int()

for i in range(out.size(1)):
start_idx = sizes[i] - 1
out.data[start_idx:,i, :] = torch.zeros(seq_length-start_idx, out.size(2))-1

#print((out.transpose(0,1)).data)

if decoder is None:
# add output to data array, and continue
output_data.append((out.data.cpu().numpy(), sizes.numpy()))
Expand All @@ -90,7 +102,8 @@
target_strings = target_decoder.convert_to_strings(split_targets)
wer, cer = 0, 0
for x in range(len(target_strings)):
transcript, reference = decoded_output[x][0], target_strings[x][0]
transcript, reference = decoded_output[x][0], target_strings[x][0].replace(" '", "'")
print("Prediction: {}\nReference: {}\n-------------------------".format(transcript, reference))
wer_inst = decoder.wer(transcript, reference) / float(len(reference.split()))
cer_inst = decoder.cer(transcript, reference) / float(len(reference))
wer += wer_inst
Expand Down
7 changes: 7 additions & 0 deletions tune_decoder.py
Expand Up @@ -60,11 +60,18 @@ def decode_dataset(logits, test_dataset, batch_size, lm_alpha, lm_beta, mesh_x,
out = torch.from_numpy(logits[i][0])
sizes = torch.from_numpy(logits[i][1])

#print(logits[i][0])
# seq_length = out.size(0)
# for i in range(out.size(1)):
# start_idx = sizes[i] - 1
# out.data[start_idx:,i, :] = torch.zeros(seq_length-start_idx, out.size(2))-1

decoded_output, _ = decoder.decode(out, sizes)
target_strings = target_decoder.convert_to_strings(split_targets)
wer, cer = 0, 0
for x in range(len(target_strings)):
transcript, reference = decoded_output[x][0], target_strings[x][0]
print("Prediction: {}\nReference: {}\n-------------------------".format(transcript, reference))
wer_inst = decoder.wer(transcript, reference) / float(len(reference.split()))
cer_inst = decoder.cer(transcript, reference) / float(len(reference))
wer += wer_inst
Expand Down

0 comments on commit e81aad0

Please sign in to comment.