In [None]:
import torchaudio, torch

In [None]:
# YESNO dataset rooted at the current directory (fetched because download=True);
# every waveform goes through a Resample transform configured for 8000 -> 16000.
ds = torchaudio.datasets.YESNO(
    '.',
    download=True,
    transform=torchaudio.transforms.Resample(orig_freq=8000, new_freq=16000),
)

In [None]:
# Absolute, machine-specific path to a local checkout of the end-to-end-SLU
# project; adjust it when running on another machine.
end_to_end_SLU = '/home/gazay/code/jongleur/end-to-end-SLU'

import sys, re, os, torch
# Extend the import path with the checkout *before* importing `data` and
# `models` below (presumably those modules live in that checkout -- verify).
sys.path.append(end_to_end_SLU)
import data, models

class Model:
    """Load a trained end-to-end-SLU experiment and decode phonemes from audio.

    Construction performs three steps, in this order:

    1. read ``<experiment_path>.cfg``, rewrite its ``folder=`` entry to point
       at ``experiment_path`` (the file is modified on disk) and parse it;
    2. read the phoneme list from ``<config.folder>/pretraining/phonemes.txt``;
    3. build ``models.Model`` from the config, restore the weights stored in
       ``<experiment_path>/training/model_state.pth`` and move it to CUDA.

    Attributes:
        name: Experiment name passed to the constructor.
        experiment_path: ``end_to_end_SLU + '/experiments/' + name``.
        config: Object returned by ``data.read_config`` with a few overrides.
        dictionary: Lower-cased phoneme labels, in file order.
        model: The restored ``models.Model`` in eval mode on the GPU.
    """

    def __init__(self, name='no_unfreezing'):
        """Prepare config, phoneme dictionary and model for experiment ``name``."""
        self.name = name
        self.experiment_path = end_to_end_SLU + '/experiments/' + name
        # The dictionary step sets config.num_phonemes, so it has to run after
        # the config exists and before the model is built from that config.
        self.__prepare_config__()
        self.__prepare_dictionary__()
        self.__prepare_model__()

    def __prepare_config__(self):
        """Point the experiment's .cfg at ``experiment_path`` and parse it.

        Side effect: the ``.cfg`` file is rewritten in place.
        """
        config_path = self.experiment_path + '.cfg'
        with open(config_path, 'r') as f:
            content = f.read()
        # A callable replacement inserts the path verbatim; passing the path
        # as a template string would make re.sub interpret any backslashes or
        # group references it happens to contain.
        content = re.sub(
            r'folder=.+',
            lambda _match: 'folder=' + self.experiment_path,
            content,
        )
        with open(config_path, 'w') as f:
            f.write(content)
        self.config = data.read_config(config_path)
        # Overrides applied on top of the parsed config. NOTE(review): their
        # meaning is defined by the end-to-end-SLU code, which is not visible
        # here -- confirm the values match the checkpoint being loaded.
        self.config.Sy_intent = None
        self.config.values_per_slot = [24]

    def __prepare_model__(self):
        """Build the network, restore its saved weights and move it to CUDA."""
        self.model = models.Model(self.config).eval()
        self.model.load_state_dict(
            torch.load(self.experiment_path + '/training/model_state.pth')
        )
        self.model.cuda()

    def __prepare_dictionary__(self):
        """Read the phoneme labels and record how many there are on the config.

        Lines that are empty after stripping the trailing newline are skipped;
        the remaining labels are lower-cased and kept in file order.
        """
        phonemes_path = os.path.join(
            self.config.folder, "pretraining", "phonemes.txt"
        )
        phonemes = []
        with open(phonemes_path, 'r') as f:
            for line in f:
                label = line.rstrip("\n")
                if label != "":
                    phonemes.append(label.lower())
        self.dictionary = phonemes
        self.config.num_phonemes = len(phonemes)

    def predict(self, wav_tensor):
        """Return the highest-scoring phoneme label for each output row.

        Args:
            wav_tensor: Audio tensor for a single utterance; it is moved to
                the GPU, cast to float and given a leading dimension of size 1
                (assumes a 1-D waveform -- TODO confirm).

        Returns:
            A list of phoneme strings taken from ``self.dictionary``, one per
            row of the selected posterior matrix (no repeat collapsing here).
        """
        signal = wav_tensor.cuda().float().unsqueeze(0)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            posteriors = self.model.pretrained_model.compute_posteriors(signal)
        logits = posteriors[0][0].detach().cpu()
        # Look labels up on this instance rather than on a notebook global.
        return [self.dictionary[i] for i in logits.argmax(dim=1).tolist()]
    

In [None]:
# Build the wrapper with its default experiment name ('no_unfreezing').
# Construction rewrites the experiment's .cfg file and moves the model to CUDA.
m = Model()

In [None]:
import editdistance

def encode_targets(targets):
    """Expand a sequence of 0/1 labels into one flat list of phoneme symbols.

    Each element of ``targets`` must expose ``.item()`` returning 0 or 1.
    Label 0 becomes ``['n', 'ow']`` and label 1 becomes ``['y', 'eh', 's']``
    (presumably the words "no" and "yes"); any other value raises KeyError.
    """
    phonemes_by_label = {0: ['n', 'ow'], 1: ['y', 'eh', 's']}
    encoded = []
    for target in targets:
        encoded.extend(phonemes_by_label[target.item()])
    return encoded
def remove_repeats(seq):
    """Collapse runs of identical symbols and drop every 'sp' symbol.

    A symbol is kept only when it is not 'sp' and differs from the element
    immediately before it in ``seq``. Because the comparison is made against
    the original sequence, an 'sp' between two equal symbols keeps both.
    """
    collapsed = []
    for index, symbol in enumerate(seq):
        if symbol == 'sp':
            continue
        if index > 0 and seq[index - 1] == symbol:
            continue
        collapsed.append(symbol)
    return collapsed

# For every item yielded by the DataLoader (default batch size): decode the
# phonemes, print the collapsed hypothesis, then print its edit distance to
# the reference phoneme sequence divided by the reference length.
for wav_tensors, sample_rate, targets in torch.utils.data.DataLoader(ds):
    # [0][0] takes the first entry of the batch and then its first row
    # (presumably the single audio channel -- confirm against the dataset).
    phns = m.predict(wav_tensors[0][0])
    # Collapse once and reuse the result for both the printout and the score.
    hypothesis = remove_repeats(phns)
    print(hypothesis)
    encoded_targets = encode_targets(targets)
    print(editdistance.eval(encoded_targets, hypothesis) / len(encoded_targets))
