In [72]:
import gc
import itertools as it
import os.path as osp
import warnings
from collections import deque, namedtuple

from fairseq.data import Dictionary
import numpy as np
import torch
from examples.speech_recognition.data.replabels import unpack_replabels
from fairseq import tasks
from fairseq.utils import apply_to_sample
from omegaconf import open_dict
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
import torch
from tqdm import tqdm

In [2]:


from wav2letter.common import create_word_dict, load_words
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from wav2letter.decoder import (
    CriterionType,
    DecoderOptions,
    KenLM,
    LM,
    LMState,
    SmearingMode,
    Trie,
    LexiconDecoder,
)



In [3]:
class W2lDecoder(object):
    def __init__(self, args, tgt_dict):
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = args['nbest']

        # criterion-specific init
        if args['criterion'] == "ctc":
            self.criterion_type = CriterionType.CTC
            self.blank = (
                tgt_dict.index("<ctc_blank>")
                if "<ctc_blank>" in tgt_dict.indices
                else tgt_dict.bos()
            )
            self.asg_transitions = None
        elif args.criterion == "asg_loss":
            self.criterion_type = CriterionType.ASG
            self.blank = -1
            self.asg_transitions = args.asg_transitions
            self.max_replabel = args.max_replabel
            assert len(self.asg_transitions) == self.vocab_size ** 2
        else:
            raise RuntimeError(f"unknown criterion: {args.criterion}")

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions = self.get_emissions(models, encoder_input)
        return self.decode(emissions)

    def get_emissions(self, models, encoder_input):
        """Run encoder and normalize emissions"""
        # encoder_out = models[0].encoder(**encoder_input)
        encoder_out = models[0](**encoder_input)
        if self.criterion_type == CriterionType.CTC:
            emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)
        elif self.criterion_type == CriterionType.ASG:
            emissions = encoder_out["encoder_out"]
        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
        idxs = (g[0] for g in it.groupby(idxs))
        if self.criterion_type == CriterionType.CTC:
            idxs = filter(lambda x: x != self.blank, idxs)
        elif self.criterion_type == CriterionType.ASG:
            idxs = filter(lambda x: x >= 0, idxs)
            idxs = unpack_replabels(list(idxs), self.tgt_dict, self.max_replabel)
        return torch.LongTensor(list(idxs))



In [4]:
class W2lKenLMDecoder(W2lDecoder):
    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

        self.silence = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )
        self.lexicon = load_words(args['lexicon'])
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(args['kenlm_model'], self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        start_state = self.lm.start(False)
        for i, (word, spellings) in enumerate(self.lexicon.items()):
            word_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                assert (
                    tgt_dict.unk() not in spelling_idxs
                ), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = DecoderOptions(
            args['beam'],
            int(getattr(args, "beam_size_token", len(tgt_dict))),
            args['beam_threshold'],
            args['lm_weight'],
            args['word_score'],
            args['unk_weight'],
            args['sil_weight'],
            0,
            False,
            self.criterion_type,
        )

        if self.asg_transitions is None:
            N = 768
            # self.asg_transitions = torch.FloatTensor(N, N).zero_()
            self.asg_transitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
            False,
        )

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []
        for b in range(B):
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append(
                [
                    {
                        "tokens": self.get_tokens(result.tokens),
                        "score": result.score,
                        "words": [
                            self.word_dict.get_entry(x) for x in result.words if x >= 0
                        ],
                    }
                    for result in nbest_results
                ]
            )
        return hypos



In [5]:
class W2lViterbiDecoder(W2lDecoder):
    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []
        if self.asg_transitions is None:
            transitions = torch.FloatTensor(N, N).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        return [
            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
            for b in range(B)
        ]


In [6]:
ls 

Untitled.ipynb


In [79]:
def get_decoder(decoder=None,BEAM=128,LM_WEIGHT=2,WORD_SCORE=-1):
  
    args_lm = {}
    args_lm['lexicon'] = '/home/harveen.chadha/_/experiments/experiment_5/english-asr-challenge/data/training/lexicon.lst'
    args_lm['kenlm_model'] = '/home/harveen.chadha/_/english-asr-challenge/lm/lm.binary'
    args_lm['beam'] = BEAM
    args_lm['beam_threshold'] = 128
    args_lm['lm_weight'] = LM_WEIGHT
    args_lm['word_score'] = WORD_SCORE
    args_lm['unk_weight'] = -np.inf
    args_lm['sil_weight'] = 0
    args_lm['nbest'] = 1
    args_lm['criterion'] ='ctc'
    args_lm['labels']='ltr'
    
    if decoder=="kenlm":
        target_dict = Dictionary.load("/home/harveen.chadha/_/experiments/experiment_5/english-asr-challenge/data/training/dict.ltr.txt")
        k = W2lKenLMDecoder(args_lm, target_dict)
    else:
        target_dict = Dictionary.load("/home/harveen.chadha/_/experiments/experiment_5/english-asr-challenge/data/training/dict.ltr.txt")
        k=W2lViterbiDecoder(args_lm, target_dict)
    return k

In [80]:
def post_process(sentence: str, symbol: str):
    if symbol == "sentencepiece":
        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
    elif symbol == 'wordpiece':
        sentence = sentence.replace(" ", "").replace("_", " ").strip()
    elif symbol == 'letter':
        sentence = sentence.replace(" ", "").replace("|", " ").strip()
    elif symbol == "_EOW":
        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
    elif symbol is not None and symbol != 'none':
        sentence = (sentence + " ").replace(symbol, "").rstrip()
    return sentence


# def get(i):
#     target_dict = Dictionary.load("'/home/harveen.chadha/_/experiments/experiment_5/english-asr-challenge/data/training/dict.ltr.txt'")
#     hyp_pieces = target_dict.string(get_decoder().decode(i)[0][0]["tokens"].int().cpu())
#     text=post_process(hyp_pieces, 'letter')
#     return text

In [81]:
# kenlm_decoder = get_decoder(decoder='kenlm')

In [82]:
def get(i,LM_WEIGHT,WORD_SCORE):
  return " ".join(get_decoder("kenlm",LM_WEIGHT=LM_WEIGHT,WORD_SCORE=WORD_SCORE).decode(i)[0][0]["words"])

In [83]:
emmissions = np.load('/home/harveen.chadha/_/experiments/experiment_5/english-asr-challenge/scripts/inference/dev_IITM.npy', allow_pickle=True)


In [84]:
emmissions = [torch.from_numpy(e) for e in emmissions]

In [85]:
emmissions = [e.unsqueeze(0) for e in emmissions]

In [86]:
lm_weight = [1,3.1]
word_score = [-2,0.1]

In [87]:
import numpy as np

In [88]:
lm_weight_range = np.arange(lm_weight[0], lm_weight[1], 0.1)
word_score_range = np.arange(word_score[0], word_score[1], 0.1)

In [89]:
combinations = []

for i in lm_weight_range:
    for j in word_score_range:
        combinations.append([i,j])
combinations.reverse()

In [122]:
combinations

[[3.0000000000000018, 1.7763568394002505e-15],
 [3.0000000000000018, -0.09999999999999831],
 [3.0000000000000018, -0.1999999999999984],
 [3.0000000000000018, -0.2999999999999985],
 [3.0000000000000018, -0.3999999999999986],
 [3.0000000000000018, -0.49999999999999867],
 [3.0000000000000018, -0.5999999999999988],
 [3.0000000000000018, -0.6999999999999988],
 [3.0000000000000018, -0.7999999999999989],
 [3.0000000000000018, -0.899999999999999],
 [3.0000000000000018, -0.9999999999999991],
 [3.0000000000000018, -1.0999999999999992],
 [3.0000000000000018, -1.1999999999999993],
 [3.0000000000000018, -1.2999999999999994],
 [3.0000000000000018, -1.3999999999999995],
 [3.0000000000000018, -1.4999999999999996],
 [3.0000000000000018, -1.5999999999999996],
 [3.0000000000000018, -1.6999999999999997],
 [3.0000000000000018, -1.7999999999999998],
 [3.0000000000000018, -1.9],
 [3.0000000000000018, -2.0],
 [2.9000000000000017, 1.7763568394002505e-15],
 [2.9000000000000017, -0.09999999999999831],
 [2.900000

In [94]:
valid_iitm = []
with open('/home/harveen.chadha/_/experiments/experiment_5/english-asr-challenge/data/dev/dev_IITM/valid.wrd') as file:
    valid_iitm = file.readlines()

In [98]:
list_results = []
for comb in combinations:
    for index, local_emmission in enumerate(emmissions):
        predicted = get(local_emmission, comb[0], comb[1])
        ground_truth = valid_iitm[index]
        
        list_results.append([comb[0], comb[1], predicted, ground_truth])

KeyboardInterrupt: 

In [124]:
import pandas as pd

In [125]:
pd.DataFrame(combinations)

Unnamed: 0,0,1
0,3.0,1.776357e-15
1,3.0,-1.000000e-01
2,3.0,-2.000000e-01
3,3.0,-3.000000e-01
4,3.0,-4.000000e-01
...,...,...
436,1.0,-1.600000e+00
437,1.0,-1.700000e+00
438,1.0,-1.800000e+00
439,1.0,-1.900000e+00


In [120]:
from joblib import Parallel, delayed
from joblib import wrap_non_picklable_objects
from joblib.externals.loky import set_loky_pickler
from joblib import parallel_backend

In [121]:
list_results = []


def get_single_prediction(index, local_emission, combination):
    predicted = get(local_emmission, comb[0], comb[1])
    ground_truth = valid_iitm[index]    
    list_results.append([comb[0], comb[1], predicted, ground_truth])

@delayed
@wrap_non_picklable_objects
def get_predictions(comb):
    #return Parallel(n_jobs=-1)(delayed(get_single_prediction)(index, local_emmission, combination) for index, local_emmission in tqdm(enumerate(emmissions))) 
    for index, local_emmission in tqdm(enumerate(emmissions)):
        return get(local_emmission, comb[0], comb[1])
#         ground_truth = valid_iitm[index]    
#         return 1
#         list_results.append([comb[0], comb[1], predicted, ground_truth])
    return 1

with parallel_backend('multiprocessing'):
    Parallel(n_jobs=-1)(delayed(get_predictions)(comb) for comb in tqdm(combinations))



 14%|█▍        | 63/441 [05:05<30:34,  4.85s/it]

  0%|          | 1/441 [00:00<01:54,  3.85it/s][A

PicklingError: Can't pickle <function get_predictions at 0x7f16d85e85e0>: attribute lookup get_predictions on joblib.externals.loky.cloudpickle_wrapper failed

In [27]:
ls '/home/harveen.chadha/_/experiments/experiment_5/english-asr-challenge/data/training/dict.ltr.txt'

/home/harveen.chadha/_/experiments/experiment_5/english-asr-challenge/data/training/dict.ltr.txt


In [178]:
# for i in emmissions:
#     print(get(i, 1.0, -0.1))
#     #kenlm_decoder.decode(i)

KeyError: 'words'