In [None]:
import sys
sys.path.append("../src/")
import pickle
import numpy as np
import torch
from IPython.display import Audio, display

from train import AudioTokenizer
from utils import LibriSpeechDataset

In [2]:
def remove_padding(z_ctx, z_ctx_quant, indices, utts_masks):
    """
    Trims frames corresponding to contextual padding.

    For every utterance ``i`` (one per entry of ``utts_masks``), frames whose
    mask value is non-zero are dropped and only frames where the mask is zero
    are kept. The same keep-mask is applied to that utterance's continuous
    embeddings, quantized embeddings and token indices.

    Args:
        z_ctx: per-utterance frame embeddings, indexable as ``z_ctx[i]`` and
            then by ``[frame_mask, :]``; presumably a (batch, frames, dim)
            tensor -- TODO confirm against the model output.
        z_ctx_quant: quantized embeddings, same layout as ``z_ctx``.
        indices: per-utterance token indices, indexable as ``indices[i]`` and
            then by a frame mask; presumably (batch, frames).
        utts_masks: sequence of per-utterance frame masks; a non-zero entry
            marks a padding frame to be removed.

    Returns:
        Three lists ``(z_ctx_trimmed, z_ctx_quant_trimmed, indices_trimmed)``,
        each with one entry per mask in ``utts_masks``.
    """
    z_ctx_trimmed, z_ctx_quant_trimmed, indices_trimmed = [], [], []
    for i, mask in enumerate(utts_masks):
        # Compute the keep-mask once and reuse it for all three outputs.
        keep = ~mask.bool()
        z_ctx_trimmed.append(z_ctx[i][keep, :])
        z_ctx_quant_trimmed.append(z_ctx_quant[i][keep, :])
        indices_trimmed.append(indices[i][keep])
    return z_ctx_trimmed, z_ctx_quant_trimmed, indices_trimmed

In [None]:
# Experiment locations.
# NOTE(review): absolute, machine-specific paths -- adjust for your environment.
model_dir = "/home/anup/STD/checkpoints/BiMamba/LS-clean-360/codes1024_1s_LS360"
audio_path = "/DATA/datasets/anup_datasets/datasets/LibriSpeech"
word_alignments_path = "/DATA/datasets/anup_datasets/datasets/LibriSpeech/LibriSpeech_alignment_word/words_alignment_train-clean-360.pkl"
checkpoint_path = model_dir + "/bsz:128_lr:0.0005_seg:1.0/checkpoints/epoch=539-valid_loss=0.00-train_loss=0.00.ckpt"

# Configuration dict stored in the model directory (presumably the training-time
# hyper-parameters). `with` guarantees the file handle is closed.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open(model_dir + "/params.pkl", "rb") as params_file:
    cfg = pickle.load(params_file)

# Word-level dataset built from the feature settings in `cfg`; augmentation is disabled.
dataset = LibriSpeechDataset(alignments_path=word_alignments_path,
                             audio_dir=audio_path,
                             feats_type=cfg['feats_type'],
                             n_mels=cfg['n_mels'],
                             fs=cfg['fs'],
                             rf_size=cfg['win_size'],
                             stride=cfg['win_hop'],
                             inp_len=cfg['seglen'],
                             add_augment=False,
                             noise_dir_path=None,
                             snr_range=None)

model = AudioTokenizer.load_from_pretrained(checkpoint=checkpoint_path,
                                            gpu_index=0)

In [None]:
# Retrieve a pair of utterances of the same spoken term.
# A seeded generator makes the pick reproducible under Restart & Run All;
# change SEED (or set it to None) to inspect a different term.
SEED = 0
rng = np.random.default_rng(SEED)
sample_idx = int(rng.integers(len(dataset)))
sample_dict = dataset[sample_idx]
fixed_len_utts_spec = sample_dict['utterances']
# Per-utterance frame masks; remove_padding drops frames where these are non-zero.
utts_masks = sample_dict['utterance_idx']
# 'utterance_len' divided by the sampling rate -- presumably sample counts -> seconds; confirm.
utts_len = np.array(sample_dict['utterance_len']) / cfg['fs']
utt_word = sample_dict['word']

print(f"spoken term: {utt_word}")
# Play back both utterances (with added contextual padding) of the spoken term.
utt1 = Audio(sample_dict["utterances_raw"][0], rate=cfg["fs"])
utt2 = Audio(sample_dict["utterances_raw"][1], rate=cfg["fs"])
display(utt1, utt2)

In [5]:
# Tokenize both utterances as one batch, then drop the contextual-padding frames.
batch = torch.stack([fixed_len_utts_spec[0], fixed_len_utts_spec[1]], dim=0)
# Inference only: no_grad avoids recording an autograd graph for this forward pass.
with torch.no_grad():
    z_ctx, z_ctx_quant, indices_padded = model.predict_step(batch.cuda(), 1)
_, _, indices = remove_padding(z_ctx, z_ctx_quant, indices_padded, list(utts_masks))

In [None]:
# Show each utterance's token sequence after padding removal.
for utt_num in (1, 2):
    print(f"utterance {utt_num} tokens: {indices[utt_num - 1]}")