In [1]:
from .wrapper import CifRnntWrapper

In [2]:
# Inference settings handed to the wrapper constructor in the next cell.
device = "cuda:0"
lang_dir = "cif_rnnt/lang_bpe_500"
model_filepath = "cif_rnnt/meanatt_3gram_nwords_ep40avg11.pt"

# Demo audio: a single utterance, and a two-utterance batch that starts with
# that same utterance.
one_file = "test_audio/1688-142285-0000.flac"
two_files = [one_file, "test_audio/1688-142285-0087.flac"]

# Reference transcripts keyed by audio path, listed in the same order as
# `two_files`.
groundtruths = dict(
    zip(
        two_files,
        (
            "THERE'S IRON THEY SAY IN ALL OUR BLOOD AND A GRAIN OR TWO PERHAPS IS GOOD BUT HIS HE MAKES ME HARSHLY FEEL HAS GOT A LITTLE TOO MUCH OF STEEL ANON",
            "MISSUS THORNTON THE ONLY MOTHER HE HAS I BELIEVE SAID MISTER HALE QUIETLY",
        ),
    )
)

In [3]:
# Build the wrapper once from the settings defined in the configuration cell;
# the cells that follow reuse this single `decoder` instance.
decoder = CifRnntWrapper(model_filepath=model_filepath, lang_dir=lang_dir, device=device)

  model_params = torch.load(model_filepath, map_location=device)


In [4]:
def sample_decode(
    *filenames: str,
):
    """Decode audio files and align them with their reference transcripts.

    Uses the notebook-level ``decoder`` object, calling its stages in order
    (``file_to_wav`` -> ``wav_to_mel`` -> ``mel_to_awe`` -> ``awe_to_text``),
    then calls ``awe_text_to_alignment`` with the reference transcripts looked
    up in the notebook-level ``groundtruths`` mapping.

    Args:
        *filenames: Audio file paths, one positional argument per file
            (unpack a list with ``sample_decode(*paths)``). Every path must be
            a key of ``groundtruths``.

    Returns:
        A dict with the intermediate and final results of the run:
        ``"wavs"``, ``"awes"``, ``"awe_lens"``, ``"hyps"`` and ``"alignment"``,
        each holding exactly what the corresponding ``decoder`` call returned.

    Raises:
        KeyError: If any filename has no entry in ``groundtruths``. This is
            raised before any decoding work is started.
    """
    # Resolve the references up front so that a mistyped path fails
    # immediately rather than after the whole decoding pipeline has run.
    missing = [f for f in filenames if f not in groundtruths]
    if missing:
        raise KeyError(f"no ground-truth transcript for: {missing}")
    refs = [groundtruths[f] for f in filenames]

    wavs = decoder.file_to_wav(*filenames)
    mels, mel_lens = decoder.wav_to_mel(wavs)
    awes, awe_lens = decoder.mel_to_awe(mels, mel_lens)
    hyps = decoder.awe_to_text(awes, awe_lens)

    alignment = decoder.awe_text_to_alignment(awes, awe_lens, refs)

    return {
        "wavs": wavs,
        "awes": awes,
        "awe_lens": awe_lens,
        "hyps": hyps,
        "alignment": alignment
    }

In [5]:
# Decode the single sample file; the bare `ret` on the last line makes the
# notebook display the returned dict.
ret = sample_decode(one_file)
ret

{'wavs': [tensor([[0.0913, 0.0916, 0.0918,  ..., 0.0173, 0.0183, 0.0195]])],
 'awes': tensor([[[-0.6926, -0.2742, -0.1056,  ..., -0.0248,  0.0692, -0.0206],
          [-0.7274, -0.3990, -0.3083,  ...,  0.0058,  0.0608,  0.0386],
          [-0.3261, -0.0638, -0.7222,  ...,  0.0331, -0.0138, -0.0359],
          ...,
          [ 0.2202,  0.5394,  0.1839,  ..., -0.0113, -0.0025, -0.0076],
          [ 0.4680, -0.1542, -0.0192,  ..., -0.0406,  0.0166,  0.0328],
          [ 0.0058,  0.0570, -0.2888,  ...,  0.0277, -0.0632,  0.0543]]],
        device='cuda:0'),
 'awe_lens': tensor([29], device='cuda:0', dtype=torch.int32),
 'hyps': ["THERE'S IRON THEY SAY IN ALL OUR BLOOD AND A GRAIN OR TWO PERHAPS IS GOOD BUT HIS HE MAKES ME HARSHLY FEEL HAS GOT A LITTLE TOO MUCH OF STEEL ANON"],
 'awe2token': [[(0, ''),
   (1, "THERE'S"),
   (2, 'IRON'),
   (3, ''),
   (4, 'THEY SAY'),
   (5, 'IN ALL'),
   (6, 'OUR'),
   (7, 'BLOOD'),
   (8, 'AND'),
   (9, 'A GRAIN'),
   (10, ''),
   (11, 'OR TWO'),
   (12, 

In [6]:
# Decode both sample files in one call (the list is unpacked into positional
# arguments). This rebinds `ret`, replacing the single-file result.
ret = sample_decode(*two_files)
ret

{'wavs': [tensor([[0.0913, 0.0916, 0.0918,  ..., 0.0173, 0.0183, 0.0195]]),
  tensor([[-0.1065, -0.1060, -0.1055,  ...,  0.0111,  0.0101,  0.0099]])],
 'awes': tensor([[[-6.9262e-01, -2.7422e-01, -1.0559e-01,  ..., -2.4800e-02,
            6.9188e-02, -2.0588e-02],
          [-7.2738e-01, -3.9903e-01, -3.0826e-01,  ...,  5.8358e-03,
            6.0814e-02,  3.8594e-02],
          [-3.2611e-01, -6.3833e-02, -7.2217e-01,  ...,  3.3100e-02,
           -1.3813e-02, -3.5869e-02],
          ...,
          [ 2.2018e-01,  5.3935e-01,  1.8393e-01,  ..., -1.1320e-02,
           -2.4503e-03, -7.5547e-03],
          [ 4.6795e-01, -1.5418e-01, -1.9240e-02,  ..., -4.0595e-02,
            1.6587e-02,  3.2847e-02],
          [ 5.7840e-03,  5.6954e-02, -2.8882e-01,  ...,  2.7741e-02,
           -6.3177e-02,  5.4346e-02]],
 
         [[-5.6590e-01,  4.6722e-01, -2.9455e-02,  ..., -2.7147e-02,
            6.5345e-02, -2.3024e-02],
          [ 5.0038e-02,  7.8907e-02,  1.6104e-01,  ...,  2.9896e-02,
     

In [9]:
# A padded, batched AWE tensor can be split back into one tensor per utterance.

from torch.nn.utils.rnn import unpad_sequence

# Trim every batch entry to its own length, taken from `awe_lens`.
unpadded_awes = unpad_sequence(
    ret["awes"],
    lengths=ret["awe_lens"],
    batch_first=True,
)
# Row count of each unpadded AWE tensor, next to the length of the alignment
# produced for the same utterance.
[awe.shape[0] for awe in unpadded_awes], [len(entries) for entries in ret["alignment"]]

([29, 11], [29, 11])