In [1]:
from typing import Iterable, Tuple, List
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn.functional as F
import sentencepiece
import omegaconf
import pytorch_lightning as pl
import pandas as pd
from tqdm.auto import tqdm
import json

from src.models import ConformerLAS, ConformerCTC
from src.metrics import WER

In [642]:
def init_model(model: pl.LightningModule, ckpt_path: str) -> pl.LightningModule:
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt)
    model.eval()
    model.freeze()
    return model


def compute_wer(refs: Iterable[str], hyps: Iterable[str]) -> float:
    wer = WER()
    wer.update(refs, hyps)
    return wer.compute()[0].item()

class GreedyDecoderLAS:
    def __init__(self, model: ConformerLAS, max_steps=20):
        self.model = model
        self.max_steps = max_steps

    def __call__(self, encoded: torch.Tensor) -> str:
        
        tokens = [self.model.decoder.tokenizer.bos_id()]

        for _ in range(self.max_steps):
            
            tokens_batch = torch.tensor(tokens).unsqueeze(0)
            att_mask = self.model.make_attention_mask(torch.tensor([tokens_batch.size(-1)]))
            
            distribution = self.model.decoder(
                encoded=encoded, encoded_pad_mask=None,
                target=tokens_batch, target_mask=att_mask, target_pad_mask=None
            )
        
            best_next_token = distribution[0, -1].argmax()
            
            if best_next_token == self.model.decoder.tokenizer.eos_id():
                break

            tokens.append(best_next_token.item())
        
        return self.model.decoder.tokenizer.decode(tokens)

# Single Model

In [16]:
dataset = 'test_opus/farfield/farfield.jsonl'

## LAS

In [620]:
conf = omegaconf.OmegaConf.load("./conf/conformer_las.yaml")
conf.val_dataloader.dataset.manifest_name = dataset
conf.model.decoder.tokenizer = "./data/tokenizer/bpe_1024_bos_eos.model"

conformer_las = init_model(
    model=ConformerLAS(conf=conf),
    ckpt_path="./data/conformer_las_2epochs.ckpt"
)

In [621]:
las_decoder = GreedyDecoderLAS(conformer_las)

refs, hyps_las = [], []

for batch in tqdm(conformer_las.val_dataloader()):

    features, features_len, targets, target_len = batch

    encoded, encoded_len = conformer_las(features, features_len)
    
    for i in range(features.shape[0]):

        encoder_states = encoded[
            [i],
            :encoded_len[i],
            :
        ]

        ref_tokens = targets[i, :target_len[i]].tolist()

        refs.append(
            conformer_las.decoder.tokenizer.decode(ref_tokens)
        )
        hyps_las.append(
            las_decoder(encoder_states)
        )

MANIFEST_PATH:  /Users/alexandrmaximenko/mipt_tasks/speech-tech-mipt/week09/asr/data/test_opus/farfield/farfield.jsonl


  0%|          | 0/479 [00:01<?, ?it/s]

In [622]:
wer_las = compute_wer(refs, hyps_las)
print(wer_las)

0.42290276288986206


## CTC

In [None]:
# TODO: load models, estimate WER

In [623]:
def decode_ctc_hyps(model: ConformerCTC) -> Tuple[List[str], List[str]]:
    refs, hyps_ctc = [], []

    for batch in tqdm(model.val_dataloader()):

        features, features_len, targets, target_len = batch

        logprobs, encoded_len, preds = model(features, features_len)
        
        refs_curr = model.decoder.decode(token_ids=targets, token_ids_length=target_len)
        
        hyps_curr = model.decoder.decode(
            token_ids=preds, token_ids_length=encoded_len, unique_consecutive=True
        )
        refs.extend(refs_curr)
        hyps_ctc.extend(hyps_curr)
        
    return refs, hyps_ctc

In [624]:
conf = omegaconf.OmegaConf.load("./conf/conformer_ctc.yaml")
conf.val_dataloader.dataset.manifest_name = dataset

conformer_ctc = init_model(
    model=ConformerCTC(conf=conf),
    ckpt_path="./data/conformer_7epochs_state_dict.ckpt"
)

refs, hyps_ctc = decode_ctc_hyps(conformer_ctc)

  0%|          | 0/479 [00:01<?, ?it/s]

In [625]:
wer_ctc = compute_wer(refs, hyps_ctc)
print(wer_ctc)

0.4251968562602997


In [627]:
conf = omegaconf.OmegaConf.load("./conf/conformer_ctc_wide.yaml")
conf.val_dataloader.dataset.manifest_name = dataset

conformer_ctc_wide = init_model(
    model=ConformerCTC(conf=conf),
    ckpt_path="./data/conformer_wide_7epochs_state_dict.ckpt"
)

refs, hyps_ctc_wide = decode_ctc_hyps(conformer_ctc_wide)

  0%|          | 0/479 [00:01<?, ?it/s]

In [629]:
wer_ctc_wide = compute_wer(refs, hyps_ctc_wide)
print(wer_ctc_wide)

0.369840145111084


# ROVER: Recognizer Output Voting Error Reduction — 5 points

* [A post-processing system to yield reduced word error rates: Recognizer Output Voting Error Reduction (ROVER)](https://ieeexplore.ieee.org/document/659110)
* [Improved ROVER using Language Model Information](https://www-tlp.limsi.fr/public/asr00_holger.pdf)

Alignment + Voting

![](./images/rover_table.png)

In [630]:
from crowdkit.aggregation.texts import ROVER

In [631]:
df_for_rover = pd.DataFrame(columns=['task', 'text', 'worker'])

In [632]:
for i, ref in enumerate(refs):
    row_ctc_wide = {
        'task': i,
        'text': hyps_ctc_wide[i],
        'worker': 'ctc_wide'
    }
    row_ctc = {
        'task': i,
        'text': hyps_ctc[i],
        'worker': 'ctc'
    }
    row_las = {
        'task': i,
        'text': hyps_las[i],
        'worker': 'las'
    }
    df_for_rover.loc[i*3 ] = pd.Series(row_ctc_wide)
    df_for_rover.loc[i*3 + 1] = pd.Series(row_ctc)
    df_for_rover.loc[i*3 + 2] = pd.Series(row_las)


Добавим кривоватое взвешивание в ROVER из crowdkit.

In [633]:
from copy import deepcopy
from enum import Enum, unique
from typing import List, Callable, Dict, Optional, Tuple, cast

import attr
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from crowdkit.aggregation.base import BaseTextsAggregator


@unique
class AlignmentAction(Enum):
    DELETION = 'DELETION'
    SUBSTITUTION = 'SUBSTITUTION'
    INSERTION = 'INSERTION'
    CORRECT = 'CORRECT'


@attr.s
class AlignmentEdge:
    value: str = attr.ib()
    sources_count: Optional[int] = attr.ib()


@attr.s
class ROVER(BaseTextsAggregator):
    """Recognizer Output Voting Error Reduction (ROVER).
    This method uses dynamic programming to align sequences. Next, aligned sequences are used
    to construct the Word Transition Network (WTN):
    ![ROVER WTN scheme](https://tlk.s3.yandex.net/crowd-kit/docs/rover.png)
    Finally, the aggregated sequence is the result of majority voting on each edge of the WTN.
    J. G. Fiscus,
    "A post-processing system to yield reduced word error rates: Recognizer Output Voting Error Reduction (ROVER),"
    *1997 IEEE Workshop on Automatic Speech Recognition and Understanding Proceedings*, 1997, pp. 347-354.
    <https://doi.org/10.1109/ASRU.1997.659110>
    Args:
        tokenizer: A callable that takes a string and returns a list of tokens.
        detokenizer: A callable that takes a list of tokens and returns a string.
        silent: If false, show a progress bar.
    Examples:
        >>> from crowdkit.aggregation import load_dataset
        >>> from crowdkit.aggregation import ROVER
        >>> df, gt = load_dataset('crowdspeech-test-clean')
        >>> df['text'] = df['text'].apply(lambda s: s.lower())
        >>> tokenizer = lambda s: s.split(' ')
        >>> detokenizer = lambda tokens: ' '.join(tokens)
        >>> result = ROVER(tokenizer, detokenizer).fit_predict(df)
    Attributes:
        texts_ (Series): Tasks' texts.
            A pandas.Series indexed by `task` such that `result.loc[task, text]`
            is the task's text.
    """

    tokenizer: Callable[[str], List[str]] = attr.ib()
    detokenizer: Callable[[List[str]], str] = attr.ib()
    silent: bool = attr.ib(default=True)

    # Available after fit
    # texts_

    def fit(self, data: pd.DataFrame, weights: List[float]) -> 'ROVER':
        """Fits the model. The aggregated results are saved to the `texts_` attribute.
        Args:
            data (DataFrame): Workers' text outputs.
                A pandas.DataFrame containing `task`, `worker` and `text` columns.
        Returns:
            ROVER: self.
        """

        result = {}
        grouped_tasks = data.groupby('task') if self.silent else tqdm(data.groupby('task'))
        for task, df in grouped_tasks:
            hypotheses = [self.tokenizer(text) for i, text in enumerate(df['text'])]

            edges = self._build_word_transition_network(hypotheses, weights)
            rover_result = self._get_result(edges)

            text = self.detokenizer([value for value in rover_result if value != ''])

            result[task] = text

        texts = pd.Series(result, name='text')
        texts.index.name = 'task'
        self.texts_ = texts

        return self

    def fit_predict(self, data: pd.DataFrame, weights) -> pd.Series:
        """Fit the model and return the aggregated texts.
        Args:
            data (DataFrame): Workers' text outputs.
                A pandas.DataFrame containing `task`, `worker` and `text` columns.
        Returns:
            Series: Tasks' texts.
                A pandas.Series indexed by `task` such that `result.loc[task, text]`
                is the task's text.
        """

        self.fit(data, weights)
        return self.texts_

    def _build_word_transition_network(self, hypotheses: List[List[str]], weights: List[float]) -> List[Dict[str, AlignmentEdge]]:
        edges = [{edge.value: edge} for edge in self._get_edges_for_words(hypotheses[0], weights[0])]

        for sources_count, hyp in enumerate(hypotheses[1:], start=1):
            edges = self._align(edges, self._get_edges_for_words(hyp, weights[sources_count]), sources_count, weights[sources_count])

        return edges

    @staticmethod
    def _get_edges_for_words(words: List[str], weight: float) -> List[AlignmentEdge]:
        return [AlignmentEdge(word, weight) for word in words]

    @staticmethod
    def _align(
            ref_edges_sets: List[Dict[str, AlignmentEdge]],
            hyp_edges: List[AlignmentEdge],
            sources_count: int,
            weight: float
    ) -> List[Dict[str, AlignmentEdge]]:
        """Sequence alignment algorithm implementation.
        Aligns a sequence of sets of tokens (edges) with a sequence of tokens using dynamic programming algorithm. Look
        for section 2.1 in <https://doi.org/10.1109/ASRU.1997.659110> for implementation details. Penalty for
        insert/deletion or mismatch is 1.
        Args:
           ref_edges_sets: Sequence of sets formed from previously aligned sequences.
           hyp_edges: Tokens from hypothesis (currently aligned) sequence.
           sources_count: Number of previously aligned sequences.
        """

        distance = np.zeros((len(hyp_edges) + 1, len(ref_edges_sets) + 1))
        distance[:, 0] = np.arange(len(hyp_edges) + 1)
        distance[0, :] = np.arange(len(ref_edges_sets) + 1)

        memoization: List[List[Optional[Tuple[AlignmentAction, Dict[str, AlignmentEdge], AlignmentEdge]]]] = [
            [None] * (len(ref_edges_sets) + 1) for _ in range(len(hyp_edges) + 1)
        ]

        for i, hyp_edge in enumerate(hyp_edges, start=1):
            memoization[i][0] = (
                AlignmentAction.INSERTION,
                {'': AlignmentEdge('', sources_count)},
                hyp_edge,
            )
        for i, ref_edges in enumerate(ref_edges_sets, start=1):
            memoization[0][i] = (
                AlignmentAction.DELETION,
                ref_edges,
                AlignmentEdge('', 1),
            )

        # find alignment minimal cost using dynamic programming algorithm
        for i, hyp_edge in enumerate(hyp_edges, start=1):
            hyp_word = hyp_edge and hyp_edge.value
            for j, ref_edges in enumerate(ref_edges_sets, start=1):
                ref_words_set = ref_edges.keys()
                is_hyp_word_in_ref = hyp_word in ref_words_set

                options = []

                if is_hyp_word_in_ref:
                    options.append((
                        distance[i - 1, j - 1],
                        (
                            AlignmentAction.CORRECT,
                            ref_edges,
                            hyp_edge,
                        )
                    ))
                else:
                    options.append((
                        distance[i - 1, j - 1] + 1,
                        (
                            AlignmentAction.SUBSTITUTION,
                            ref_edges,
                            hyp_edge,
                        )
                    ))
                options.append((
                    distance[i, j - 1] + ('' not in ref_edges),
                    (
                        AlignmentAction.DELETION,
                        ref_edges,
                        AlignmentEdge('', 1),
                    )
                ))
                options.append((
                    distance[i - 1, j] + 1,
                    (
                        AlignmentAction.INSERTION,
                        {'': AlignmentEdge('', sources_count)},
                        hyp_edge,
                    )
                ))

                distance[i, j], memoization[i][j] = min(options, key=lambda t: t[0])  # type: ignore

        alignment = []
        i = len(hyp_edges)
        j = len(ref_edges_sets)

        # reconstruct answer from dp array
        while i != 0 or j != 0:
            action, ref_edges, hyp_edge = cast(Tuple[AlignmentAction, Dict[str, AlignmentEdge], AlignmentEdge],
                                               memoization[i][j])
            joined_edges = deepcopy(ref_edges)
            hyp_edge_word = hyp_edge.value
            if hyp_edge_word not in joined_edges:
                joined_edges[hyp_edge_word] = hyp_edge
            else:
                # if word is already in set increment sources count for future score calculation
                joined_edges[hyp_edge_word].sources_count += weight  # type: ignore
            alignment.append(joined_edges)
            if action == AlignmentAction.CORRECT or action == AlignmentAction.SUBSTITUTION:
                i -= 1
                j -= 1
            elif action == AlignmentAction.INSERTION:
                i -= 1
            # action == AlignmentAction.DELETION
            else:
                j -= 1

        return alignment[::-1]

    @staticmethod
    def _get_result(edges: List[Dict[str, AlignmentEdge]]) -> List[str]:
        result = []
        for edges_set in edges:
            _, _, value = max((x.sources_count, len(x.value), x.value) for x in edges_set.values())
            result.append(value)
        return result

Сначала саггрегируем с одинаковыми весами

In [635]:
# TODO: aggregate hypotheses, estimate WER
tokenizer = lambda s: s.split(' ')
detokenizer = lambda tokens: ' '.join(tokens)
rover_crowdkit = ROVER(tokenizer, detokenizer)
results = rover_crowdkit.fit_predict(df_for_rover, weights=[1, 1, 1])
compute_wer(refs, results)

0.35707467794418335

Теперь чуть поменяем веса в кривом взвешенном Rover-е и станет чуть лучше.

In [637]:
results = rover_crowdkit.fit_predict(df_for_rover, weights=[1.5, 1, 1.5])
compute_wer(refs, results)

0.347053200006485

# MBR: Minimum Bayes Risk — 5 points


* [Minimum Bayes Risk Decoding and System
Combination Based on a Recursion for Edit Distance](https://danielpovey.com/files/csl11_consensus.pdf)
* [mbr-decoding blog-post](https://suzyahyah.github.io/bayesian%20inference/machine%20translation/2022/02/15/mbr-decoding.html)
* [Combination of end-to-end and hybrid models for speech recognition](http://www.interspeech2020.org/uploadfile/pdf/Tue-1-8-4.pdf)

![](./images/mbr_scheme.png)

In [643]:
# Добавим нужную функцию get_hyp_probability в декодер

class GreedyDecoderLASv1(GreedyDecoderLAS):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    
    def get_hyp_probability(self, encoded: torch.Tensor, hyp):
        tokens = [self.model.decoder.tokenizer.bos_id()]
        score = 1
        softmax = nn.Softmax()
        for i in range(1, len(hyp)):
            tokens_batch = torch.tensor(tokens).unsqueeze(0)
            att_mask = self.model.make_attention_mask(torch.tensor([tokens_batch.size(-1)]))
            
            distribution = self.model.decoder(
                encoded=encoded, encoded_pad_mask=None,
                target=tokens_batch, target_mask=att_mask, target_pad_mask=None
            )
            
            score *= softmax(distribution[0, -1])[hyp[i]]
            tokens.append(hyp[i])

        return score

In [644]:
las_decoder = GreedyDecoderLASv1(conformer_las)
def get_hyp_prob_las(hyp, encoder_states, las_decoder=las_decoder, conformer_las=conformer_las):
    hyp_tokens = [conformer_las.decoder.tokenizer.bos_id()]
    hyp_tokens.extend(conformer_las.decoder.tokenizer.tokenize(hyp))
    hyp_tokens.append(conformer_las.decoder.tokenizer.eos_id())
    return las_decoder.get_hyp_probability(encoder_states, hyp_tokens)

In [645]:
labels = [ " ", "а", "б", "в", "г", "д", "е", "ж", "з", "и", "й", "к", "л", "м", "н", "о", "п", "р", "с", "т", "у", "ф", "х", "ц", "ч", "ш", "щ", "ъ", "ы", "ь", "э", "ю", "я" ]
token_to_idx = {token: idx for idx, token in enumerate(labels)}

def get_hyp_prob_ctc(hyp, model, logprobs, encoded_len):
    if '⁇' in hyp:
        return None
    hyp_tokens = torch.Tensor([int(token_to_idx[token]) for token in hyp]).long()
    target_len = len(hyp)
    loss = model.ctc_loss(
        logprobs, hyp_tokens, torch.Tensor(encoded_len), torch.tensor(target_len)
    )
    prob = torch.exp(-loss)
    return prob

In [650]:
import editdistance
def get_distance_matrix(hyps: List[str]):
    distance_matrix = np.zeros((len(hyps), len(hyps)))
    for i in range(len(hyps)):
        for j in range(i + 1, len(hyps)):
            distance = editdistance.eval(hyps[i], hyps[j])
            distance_matrix[i, j] = distance
            distance_matrix[j, i] = distance
    return distance_matrix

def get_mbr_hyp(hyps, scores, weights=None):
    if weights is None:
        weights = np.ones(len(hyps))
    hyp_scores = []
    distance_matrix = get_distance_matrix(hyps)
    for i in range(len(hyps)):
        tmp_scores = []
        for j in range(len(hyps)):
            probs_sum = np.sum(scores[j])
            tmp_score = np.dot(distance_matrix[i], scores[j]) / probs_sum
            tmp_scores.append(tmp_score)
        hyp_scores.append(np.dot(weights, tmp_scores))
    return hyps[np.argmin(hyp_scores)]


Посчитаем скоры моделей для использвания в MBR.

In [647]:
batch_size = 4
scores_las = []
scores_ctc = []
scores_ctc_wide = []
k = -1

for i, batch in tqdm(enumerate(conformer_las.val_dataloader()), total=len(conformer_las.val_dataloader())):
    features, features_len, targets, target_len = batch
    curr_hyp_las = hyps_las[batch_size * i: batch_size * i + batch_size]
    curr_hyp_ctc = hyps_ctc[batch_size * i: batch_size * i + batch_size]
    curr_hyp_ctc_wide = hyps_ctc_wide[batch_size * i: batch_size * i + batch_size]
    
    encoded_las, encoded_len = conformer_las(features, features_len)
    logprobs_ctc, encoded_len, preds_ctc = conformer_ctc(features, features_len)
    logprobs_ctc_wide, encoded_len, preds_ctc_wide = conformer_ctc_wide(features, features_len)

    for i in range(features.shape[0]):
        curr_scores_las = []
        curr_scores_ctc = []
        curr_scores_ctc_wide = []

        encoder_states_las = encoded_las[
            [i],
            :encoded_len[i],
            :
        ]
        
        curr_scores_las.append(get_hyp_prob_las(curr_hyp_las[i], encoder_states_las))
        curr_scores_las.append(get_hyp_prob_las(curr_hyp_ctc[i], encoder_states_las))
        curr_scores_las.append(get_hyp_prob_las(curr_hyp_ctc_wide[i], encoder_states_las))
        
        curr_scores_ctc.append(
            get_hyp_prob_ctc(curr_hyp_las[i], conformer_ctc, logprobs_ctc[i], encoded_len[i])
        )
        curr_scores_ctc.append(
            get_hyp_prob_ctc(curr_hyp_ctc[i], conformer_ctc, logprobs_ctc[i], encoded_len[i])
        )
        curr_scores_ctc.append(
            get_hyp_prob_ctc(curr_hyp_ctc_wide[i], conformer_ctc, logprobs_ctc[i], encoded_len[i])
        )
        
        curr_scores_ctc_wide.append(
            get_hyp_prob_ctc(curr_hyp_las[i], conformer_ctc_wide, logprobs_ctc_wide[i], encoded_len[i])
        )
        curr_scores_ctc_wide.append(
            get_hyp_prob_ctc(curr_hyp_ctc[i], conformer_ctc_wide, logprobs_ctc_wide[i], encoded_len[i])
        )
        curr_scores_ctc_wide.append(
            get_hyp_prob_ctc(curr_hyp_ctc_wide[i], conformer_ctc_wide, logprobs_ctc_wide[i], encoded_len[i])
        )
        
        scores_las.append(curr_scores_las)
        scores_ctc.append(curr_scores_ctc)
        scores_ctc_wide.append(curr_scores_ctc_wide)
        
        

MANIFEST_PATH:  /Users/alexandrmaximenko/mipt_tasks/speech-tech-mipt/week09/asr/data/test_opus/farfield/farfield.jsonl
MANIFEST_PATH:  /Users/alexandrmaximenko/mipt_tasks/speech-tech-mipt/week09/asr/data/test_opus/farfield/farfield.jsonl


  0%|          | 0/479 [00:00<?, ?it/s]

In [648]:
hyps_mbr = []
for i in tqdm(range(len(refs))):
    hyps = (hyps_las[i], hyps_ctc[i], hyps_ctc_wide[i])
    scores = (scores_las[i], scores_ctc[i], scores_ctc_wide[i])
    if None in scores_ctc[i]:
        hyps = (hyps_ctc[i], hyps_ctc_wide[i])
        scores = (scores_ctc[i][1:], scores_ctc_wide[i][1:])
    else:
        hyps = (hyps_las[i], hyps_ctc[i], hyps_ctc_wide[i])
        scores = (scores_las[i], scores_ctc[i], scores_ctc_wide[i])
        
    hyps_mbr.append(get_mbr_hyp(hyps, scores))

  0%|          | 0/1916 [00:00<?, ?it/s]

In [684]:
print('WER c использованием MBR: ', compute_wer(refs, hyps_mbr))

WER c использованием MBR:  0.3518253266811371


In [685]:
hyps_mbr_weighted = []
for i in tqdm(range(len(refs))):
    hyps = (hyps_las[i], hyps_ctc[i], hyps_ctc_wide[i])
    scores = (scores_las[i], scores_ctc[i], scores_ctc_wide[i])
    weights = [1.5, 1, 1.5]
    if None in scores_ctc[i]:
        hyps = (hyps_ctc[i], hyps_ctc_wide[i])
        scores = (scores_ctc[i][1:], scores_ctc_wide[i][1:])
        weights = weights[1:]
    else:
        hyps = (hyps_las[i], hyps_ctc[i], hyps_ctc_wide[i])
        scores = (scores_las[i], scores_ctc[i], scores_ctc_wide[i])
        
        
    hyps_mbr_weighted.append(get_mbr_hyp(hyps, scores, weights=weights))

print('WER с подобранными руками весами для MBR: ', compute_wer(refs, hyps_mbr_weighted))

  0%|          | 0/1916 [00:00<?, ?it/s]

WER с подобранными руками весами для MBR:  0.344547837972641


Посмотрим еще на OracleWER, интересно знать сколько вообще можно вытянуть из наших гипотез.

In [681]:
class OracleWER(WER):
    def __init__(self):

        super().__init__()

    def update(self, references: List[str], hypotheses: List[List[str]]):
        word_errors = 0.0
        words = 0.0
        for i, ref in enumerate(references):
            ref_tokens = ref.split()
            curr_hyps = [hyps[i] for hyps in hypotheses]
            hyp_tokens = [hyp.split() for hyp in curr_hyps]
            dist = min(editdistance.eval(ref_tokens, curr_hyp_tokens) for curr_hyp_tokens in hyp_tokens)
            word_errors += dist
            words += len(ref_tokens)
        self.word_errors = torch.tensor(
            word_errors, device=self.word_errors.device, dtype=self.word_errors.dtype
        )
        self.words = torch.tensor(
            words, device=self.words.device, dtype=self.words.dtype
        )

In [682]:
def compute_oracle_wer(refs: Iterable[str], hyps: Iterable[str]) -> float:
    wer = OracleWER()
    wer.update(refs, hyps)
    return wer.compute()[0].item()

In [686]:
print('Oracle WER: ', compute_oracle_wer(refs, [hyps_las, hyps_ctc, hyps_ctc_wide]))

Oracle WER:  0.282390832901001
