In [1]:
import torch
import torch.nn as nn
import numpy as np
import phoneme_list as pl
from train import PhonDataset, collate_phon, PackedPhonModel, load_ckpt
import Levenshtein as L
from torch.utils.data import Dataset, DataLoader, TensorDataset
from ctcdecode import CTCBeamDecoder


# the p_map list
p_map = pl.PHONEME_MAP
p_map.append('%')
print(len(p_map))
print(p_map.index('%'))
print(p_map)


# validation loader
val_data_path = './../data/wsj0_dev.npy'
val_label_path = './../data/wsj0_dev_merged_labels.npy'
val_data = np.load(val_data_path, encoding='bytes')
val_label = np.load(val_label_path)
val_dataset = PhonDataset(val_data, val_label)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=1, collate_fn=collate_phon)


# load model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

path = './../result/id_0'
model, _ = load_ckpt(path)
model.to(DEVICE)



47
46
['_', '+', '~', '!', '-', '@', 'a', 'A', 'h', 'o', 'w', 'y', 'b', 'c', 'd', 'D', 'e', 'r', 'E', 'f', 'g', 'H', 'i', 'I', 'j', 'k', 'l', 'm', 'n', 'G', 'O', 'Y', 'p', 'R', 's', 'S', '.', 't', 'T', 'u', 'U', 'v', 'W', '?', 'z', 'Z', '%']


PackedPhonModel(
  (rnn): LSTM(40, 512, num_layers=4, bidirectional=True)
  (scoring1): Linear(in_features=1024, out_features=1024, bias=True)
  (scoring2): Linear(in_features=1024, out_features=47, bias=True)
  (lsm): LogSoftmax()
)

In [13]:
# show string
decoder = CTCBeamDecoder(p_map, beam_width=100, blank_id=p_map.index('%'))

with torch.no_grad():
    cnt = 0
    for inputs, targets in val_loader:
        cnt += 1
        if cnt > 3:
            break
        output = model(inputs)
        sp = output.shape
        output = output.reshape((sp[0], sp[2]))
        output = output.cpu().numpy()
        
        # output of the network is N * 47, N is the length of input
#         print("output shape:", output.shape)
        
        # 1 * N * 47 list made from output
        probs_seq = torch.FloatTensor([output])
        
        output, _, _, out_seq_len = decoder.decode(probs_seq)
        
        # out_seq_len is a list of the decode output length of 100 results
#         print("out_seq_len shape", out_seq_len.shape)
#         print("out_seq_len content", out_seq_len)
        
        # output contains 100 decoded results, each decoded result is an integer list
        # each integer in the list is the index corresponding to p_map
#         print("output shape", output.shape)
#         print("first output content", output[0, 0])
        
        for i in range(output.size(0)):
            pred = "".join(p_map[o] for o in output[i, 0, :out_seq_len[i, 0]])
        
        true = "".join(p_map[o] for o in targets[0])

        dis = L.distance(pred, true)
        
        print("prediction string:")
        print(pred)
        print("")
        
        print("label string:")
        print(true)
        print("")
        
        print("distance:", dis)
        print("")

prediction string:
DnUzbRphhhhhhhhhhhhhhhhhhhhhhsssWHhHHssppp~p~~gWptRlADUHp~ddbbRm~WrhzRzZuu@@tbyu~EOtonAtovbd@+ph+t?~?ehTnWO?aa?Gpi__haYsYmsWatajijbiWWaWooYaEhg_hjcljSlczD__ip_nWOHbbhHooEuuUEIIntmkir~DpDi~Z~~hRDDh~enudtSSunHkz~rrsrtueAezp!wbgdkOipetD?p+bZS?SdYcDy._sfgw!I.SkfiiddDk++UAccvrm+yjm+wAkjApkkcToEfDfDDY+w!GTrn.EYv+fzka-jzjGaugvSbE+vfl+zds++SvyvH?_ZsbhfAAzvvEajnuRuGZZcpggfvm+fv+SkY-Yi__arusUwZvf_!ZUUkpmwlvwOwSwYYOYOuOuOuOumumumuOuOuOuOuOuOYgY?za

label string:
.DhfImElpRhdUsizhlitrhvtUtifoR?hGinnOvembrAnddisembr_.

distance: 410

prediction string:
ZASo_nphphhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhRRphhhbhhhhdh@??nuGUgIbWtWhAOAnn@hZW~.R.u~@WOEnEErrnZhk!zkUkhtvYuHOiiuunznsc~SeTZ.hU@W@AOIAb@ff~RDb~YW@OeHOOOOOHHhotOsSUAlzzgkjHceff.!+ddddaewRRRllHOOouy@lHWpdbuIurhdRzEijkZk~SvDjf@h@UUuoAGb~v@~meOIfj!!~DTjkgfrD~iiUhhulouuAt?eH~~ShUuOEuU@Tpp.@tru?OARrh.!!deeeediy@r@nlitdn@++.bbbbDYeylz+Dk+tfDy@@@YHOOlHw@@dTZZ.effidtrlIUO@nRDybh@@eheYn@@iHR?Usjdj~DvdmY~YEude~ng@ppphkiifkpkpRbRdddRRRRRRR