In [1]:
# SECURITY: a Hugging Face access token used to be pasted here as a comment.
# Anything committed in a notebook must be treated as leaked: revoke that token
# in the Hugging Face account settings and never store credentials in the file.
# Authenticate interactively (below) or through the HF_TOKEN environment variable.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
from datasets import load_dataset, Audio, Features, Value

# Schema shared by every split loaded below; the audio column is left undecoded
# (decode=False).
features = Features(
    {
        "file": Value("string"),
        "audio": Audio(decode=False),
        "text": Value("string"),
        "speaker_id": Value("int64"),
        "chapter_id": Value("int64"),
        "id": Value("string"),
    }
)


def _stream_librispeech(config, split):
    """Open one `librispeech_asr` split in streaming mode with the shared schema."""
    return load_dataset(
        "librispeech_asr",
        config,
        split=split,
        features=features,
        streaming=True,
    )


# Training splits.
train_clean = _stream_librispeech("clean", "train.360")
train_other = _stream_librispeech("other", "train.500")

# Validation splits.
val_clean = _stream_librispeech("clean", "validation")
val_other = _stream_librispeech("other", "validation")

# Test splits.
test_clean = _stream_librispeech("clean", "test")
test_other = _stream_librispeech("other", "test")

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

In [None]:
import torch
from tqdm import tqdm
import wandb

class Inferencer:
    """Evaluate an encoder/decoder speech recogniser with greedy and/or beam decoding.

    For every batch of a loader the class computes log-mel features, runs the
    encoder, decodes through the decoder's ``autoregressive_decode`` and/or
    ``beam_search_decode`` methods, converts token ids to text and accumulates
    word- and character-level edit distances against the reference text.
    Corpus-level error rates are ``total edits / total reference units``.

    Batches are expected to be mappings with ``'audio'``, ``'text'`` and
    ``'audio_lengths'`` entries that support ``.to(device)``, plus an optional
    ``'sr'`` entry (16000 is used when it is absent).

    When ``use_wandb`` is true, per-sample rows and final metrics are logged to
    a Weights & Biases run. One table is kept per column layout, so the combined
    greedy-vs-beam evaluation and the single-decoder evaluations can be mixed on
    the same instance without their rows colliding.
    """

    # Column layout of the greedy-vs-beam sample table.
    _COMPARE_TABLE_COLUMNS = (
        "dataset", "batch_idx", "sample_idx_in_batch",
        "target", "predict_beam", "wer_beam", "cer_beam",
        "predict_greedy", "wer_greedy", "cer_greedy",
    )
    # Column layout of the single-decoder (greedy-only or beam-only) sample table.
    _SINGLE_TABLE_COLUMNS = (
        "dataset", "batch_idx", "sample_idx_in_batch",
        "target", "predict", "wer", "cer",
    )

    def __init__(
        self,
        encoder,
        decoder,
        tokenizer,
        device="cuda",
        pad_id=50256,
        use_wandb=False,
        wandb_project=None,
    ):
        """Store the model parts, helper functions and logging configuration.

        Args:
            encoder: Callable model invoked as ``encoder(mels, key_padding_mask=...)``.
            decoder: Object exposing ``autoregressive_decode`` and ``beam_search_decode``.
            tokenizer: Passed through to ``tokens_to_text``.
            device: Device that batches and masks are moved to / created on.
            pad_id: Forwarded to ``tokens_to_text`` as ``pad_id``.
            use_wandb: Enable Weights & Biases logging.
            wandb_project: W&B project name; ``"asr-eval"`` is used when logging is
                enabled and no name is given.

        Note:
            ``compute_log_melspectrogram``, ``compute_downsampled_len``,
            ``tokens_to_text`` and ``levenshtein`` are looked up as module-level
            names here, so they must already be defined when an instance is created.
        """
        self.encoder = encoder
        self.decoder = decoder
        self.tokenizer = tokenizer
        self.device = device
        self.pad_id = pad_id

        self.compute_log_melspectrogram = compute_log_melspectrogram
        self.compute_downsampled_len = compute_downsampled_len
        self.tokens_to_text = tokens_to_text
        self.levenshtein = levenshtein

        self.use_wandb = use_wandb
        self.wandb_run = None
        # Most recently used table (kept for backward compatibility) and the
        # per-layout registry that prevents mixing rows of different widths.
        self._wandb_table = None
        self._wandb_tables = {}
        if self.use_wandb:
            self.wandb_project = wandb_project or "asr-eval"
        else:
            self.wandb_project = None

    # ------------------------------------------------------------------ decoding

    @torch.no_grad()
    def generate_greedy(self, enc_out, cross_key_padding_mask=None, **kwargs):
        """Delegate to ``decoder.autoregressive_decode`` with gradients disabled."""
        return self.decoder.autoregressive_decode(
            enc_out,
            cross_key_padding_mask=cross_key_padding_mask,
            device=self.device,
            **kwargs,
        )

    @torch.no_grad()
    def generate_beam(self, enc_out, beam_size=5, cross_key_padding_mask=None, **kwargs):
        """Delegate to ``decoder.beam_search_decode`` with gradients disabled."""
        return self.decoder.beam_search_decode(
            enc_out,
            beam_size=beam_size,
            cross_key_padding_mask=cross_key_padding_mask,
            device=self.device,
            **kwargs,
        )

    # ------------------------------------------------------------------- scoring

    def _decode_text(self, token_ids):
        """Convert one row of token ids to text with whitespace collapsed to single spaces."""
        text = self.tokens_to_text(token_ids.cpu(), self.tokenizer, pad_id=self.pad_id)
        return " ".join(text.split())

    def _edit_counts(self, ref_text, pred_text):
        """Return ``(word_edits, char_edits)`` between a reference and a hypothesis."""
        word_edits = self.levenshtein(ref_text.split(), pred_text.split())
        char_edits = self.levenshtein(list(ref_text), list(pred_text))
        return word_edits, char_edits

    @staticmethod
    def _rate(edits, total):
        """Return ``edits / total``, or 0.0 when there is no reference material."""
        return edits / total if total > 0 else 0.0

    def _compute_batch_metrics_single_pred(self, texts, pred_tokens):
        """Score one set of predictions against the references of a batch.

        Args:
            texts: Reference token ids; ``texts.size(0)`` is the batch size.
            pred_tokens: Predicted token ids, indexable per sample.

        Returns:
            ``(batch_summary, rows)`` where ``batch_summary`` holds the summed
            edit counts and reference lengths (plus ``batch_size``) and ``rows``
            has one dict per sample with ``ref``, ``pred``, ``wer`` and ``cer``.
        """
        B = texts.size(0)
        word_edits_sum = ref_words_sum = 0
        char_edits_sum = ref_chars_sum = 0
        rows = []

        for i in range(B):
            ref_text = self._decode_text(texts[i])
            pred_text = self._decode_text(pred_tokens[i])
            n_ref_words = len(ref_text.split())
            n_ref_chars = len(ref_text)

            word_edits, char_edits = self._edit_counts(ref_text, pred_text)
            word_edits_sum += word_edits
            ref_words_sum += n_ref_words
            char_edits_sum += char_edits
            ref_chars_sum += n_ref_chars

            rows.append({
                "ref": ref_text,
                "pred": pred_text,
                # max(1, ...) keeps per-sample rates finite for empty references.
                "wer": word_edits / max(1, n_ref_words),
                "cer": char_edits / max(1, n_ref_chars),
            })

        batch_summary = {
            "batch_word_edits": word_edits_sum,
            "batch_ref_words": ref_words_sum,
            "batch_char_edits": char_edits_sum,
            "batch_ref_chars": ref_chars_sum,
            "batch_size": B,
        }
        return batch_summary, rows

    def compare(self, texts, greedy_tokens, beam_tokens, dataset_name, batch_idx):
        """Score greedy and beam predictions of one batch against the same references.

        Args:
            texts: Reference token ids; ``texts.size(0)`` is the batch size.
            greedy_tokens: Greedy predictions, indexable per sample.
            beam_tokens: Beam predictions, indexable per sample.
            dataset_name: Stored in every row under ``dataset``.
            batch_idx: Stored in every row under ``batch_idx``.

        Returns:
            ``(batch_summary, rows)``; the summary holds summed edit counts for
            both decoders (``*_g`` greedy, ``*_b`` beam) and the shared reference
            lengths, and each row holds the texts and per-sample rates.
        """
        B = texts.size(0)
        word_edits_g = word_edits_b = ref_words_sum = 0
        char_edits_g = char_edits_b = ref_chars_sum = 0
        rows = []

        for i in range(B):
            ref_text = self._decode_text(texts[i])
            greedy_text = self._decode_text(greedy_tokens[i])
            beam_text = self._decode_text(beam_tokens[i])
            n_ref_words = len(ref_text.split())
            n_ref_chars = len(ref_text)

            g_words, g_chars = self._edit_counts(ref_text, greedy_text)
            b_words, b_chars = self._edit_counts(ref_text, beam_text)

            word_edits_g += g_words
            word_edits_b += b_words
            ref_words_sum += n_ref_words
            char_edits_g += g_chars
            char_edits_b += b_chars
            ref_chars_sum += n_ref_chars

            word_denom = max(1, n_ref_words)
            char_denom = max(1, n_ref_chars)
            rows.append({
                "dataset": dataset_name,
                "batch_idx": batch_idx,
                "sample_idx_in_batch": i,
                "ref": ref_text,
                "pred_greedy": greedy_text,
                "wer_greedy": g_words / word_denom,
                "cer_greedy": g_chars / char_denom,
                "pred_beam": beam_text,
                "wer_beam": b_words / word_denom,
                "cer_beam": b_chars / char_denom,
            })

        batch_summary = {
            "batch_word_edits_g": word_edits_g,
            "batch_word_edits_b": word_edits_b,
            "batch_ref_words": ref_words_sum,
            "batch_char_edits_g": char_edits_g,
            "batch_char_edits_b": char_edits_b,
            "batch_ref_chars": ref_chars_sum,
            "batch_size": B,
        }

        return batch_summary, rows

    # ------------------------------------------------------------ shared helpers

    def _resolve_wandb_run(self, wandb_run):
        """Make sure ``self.wandb_run`` is set before any logging happens.

        With logging disabled this returns ``None``. Without an explicit run a new
        one is started (``reinit=True``) and stored. An explicitly passed run is
        adopted whenever the instance has no run yet, so final metrics are logged
        regardless of ``log_first_per_batch_only``; a run already stored on the
        instance is kept.
        """
        if not self.use_wandb:
            return None
        if wandb_run is None:
            wandb_run = wandb.init(project=getattr(self, "wandb_project", "asr-eval"), reinit=True)
            self.wandb_run = wandb_run
        elif self.wandb_run is None:
            self.wandb_run = wandb_run
        return self.wandb_run

    def _log_table_row(self, columns, values, log_key):
        """Append ``values`` to the table for this column layout and log it under ``log_key``."""
        tables = getattr(self, "_wandb_tables", None)
        if tables is None:
            tables = self._wandb_tables = {}
        table = tables.get(columns)
        if table is None:
            table = wandb.Table(columns=list(columns), log_mode='MUTABLE')
            tables[columns] = table
        self._wandb_table = table
        table.add_data(*values)
        self.wandb_run.log({log_key: table})

    def _iter_batches(self, loader, desc, max_eval_batches):
        """Yield ``(batch_idx, batch)`` with a progress bar, stopping at ``max_eval_batches``."""
        for batch_idx, batch in enumerate(tqdm(loader, desc=desc)):
            if max_eval_batches is not None and batch_idx >= max_eval_batches:
                break
            yield batch_idx, batch

    def _encode_batch(self, batch):
        """Run feature extraction and the encoder for one batch.

        Returns:
            ``(texts, enc_out, enc_key_padding_mask)`` where the mask is a boolean
            tensor with ``True`` at encoder positions beyond each sample's length.
        """
        audio = batch['audio'].to(self.device)
        texts = batch['text'].to(self.device)
        audio_lengths = batch['audio_lengths'].to(self.device)
        sr_batch = batch.get('sr', 16000)

        B = audio.size(0)

        mels, mel_lens = self.compute_log_melspectrogram(
            audio, audio_lengths, sr=sr_batch, device=self.device, augment=False
        )
        mels = mels.transpose(2, 1)
        padded_mel_len = mels.size(1)

        # The length helper is applied twice, both to the padded length and to
        # the per-sample lengths (presumably one application per downsampling
        # stage of the encoder — confirm against the encoder definition).
        conv1_padded = self.compute_downsampled_len(torch.tensor(padded_mel_len, device=self.device))
        max_enc_len = self.compute_downsampled_len(conv1_padded)
        conv1_lens = self.compute_downsampled_len(mel_lens)
        enc_lens = self.compute_downsampled_len(conv1_lens)

        # Start fully masked, then unmask the valid prefix of every sample.
        enc_key_padding_mask = torch.ones(B, max_enc_len, device=self.device, dtype=torch.bool)
        for i in range(B):
            enc_key_padding_mask[i, :enc_lens[i]] = False

        enc_out = self.encoder(mels.unsqueeze(1), key_padding_mask=enc_key_padding_mask)
        return texts, enc_out, enc_key_padding_mask

    def _beam_best(self, enc_out, beam_size):
        """Run beam search and keep index 0 along the second dimension of its output.

        NOTE(review): beam search is called without the encoder padding mask
        while greedy decoding receives it. This mirrors the existing behavior;
        confirm whether ``beam_search_decode`` is meant to attend to padded frames.
        """
        beam_tensor = self.generate_beam(enc_out, beam_size=beam_size, cross_key_padding_mask=None)
        # Index 0 is presumably the best-scoring hypothesis — confirm with the decoder.
        return beam_tensor[:, 0, :]

    def _evaluate_single(self, loader, dataset_name, tag, predict,
                         max_eval_batches, log_first_per_batch_only, wandb_run):
        """Shared loop behind ``evaluate_with_greedy`` and ``evaluate_with_beam``.

        Args:
            tag: ``"greedy"`` or ``"beam"``; used in progress, log and result keys.
            predict: Callable ``(enc_out, enc_key_padding_mask) -> token ids``.
        """
        self._resolve_wandb_run(wandb_run)

        rows_for_table = []
        self.encoder.eval()
        self.decoder.eval()

        word_edits = ref_words = char_edits = ref_chars = 0

        with torch.no_grad():
            for batch_idx, batch in self._iter_batches(loader, f"{tag}_eval {dataset_name}", max_eval_batches):
                texts, enc_out, enc_mask = self._encode_batch(batch)
                pred_tokens = predict(enc_out, enc_mask)
                batch_summary, rows = self._compute_batch_metrics_single_pred(texts, pred_tokens)

                word_edits += batch_summary["batch_word_edits"]
                ref_words += batch_summary["batch_ref_words"]
                char_edits += batch_summary["batch_char_edits"]
                ref_chars += batch_summary["batch_ref_chars"]

                if not log_first_per_batch_only:
                    rows_for_table.extend(rows)
                    continue
                if not rows:  # empty batch: nothing to sample
                    continue

                first_row = rows[0]
                rows_for_table.append(first_row)
                if self.use_wandb:
                    self._log_table_row(
                        self._SINGLE_TABLE_COLUMNS,
                        (
                            dataset_name,
                            batch_idx,
                            0,
                            first_row["ref"],
                            first_row["pred"],
                            first_row["wer"],
                            first_row["cer"],
                        ),
                        f"eval_table/{dataset_name}_{tag}",
                    )

        metrics = {
            f"wer_{tag}": self._rate(word_edits, ref_words),
            f"cer_{tag}": self._rate(char_edits, ref_chars),
        }
        if self.use_wandb and self.wandb_run is not None:
            self.wandb_run.log({f"{dataset_name}/{key}": value for key, value in metrics.items()})

        final = {"dataset": dataset_name}
        final.update(metrics)
        return final, rows_for_table

    # ---------------------------------------------------------------- evaluation

    def evaluate_dataset(self, loader, dataset_name, beam_size=5, max_eval_batches=None, log_first_per_batch_only=True, wandb_run=None):
        """Evaluate greedy and beam decoding side by side on one loader.

        Args:
            loader: Iterable of batches.
            dataset_name: Label used in progress output, log keys and rows.
            beam_size: Beam width passed to ``generate_beam``.
            max_eval_batches: Stop after this many batches (``None`` = all).
            log_first_per_batch_only: Keep/log only the first sample of each
                batch; otherwise keep every sample and log no table rows.
            wandb_run: Run to log to; a new one is started when logging is
                enabled and this is ``None``.

        Returns:
            ``(final, rows_for_table)``: corpus-level ``wer_*``/``cer_*`` for both
            decoders and the collected per-sample rows.
        """
        self._resolve_wandb_run(wandb_run)

        rows_for_table = []

        self.encoder.eval()
        self.decoder.eval()

        word_edits_g = word_edits_b = ref_words = 0
        char_edits_g = char_edits_b = ref_chars = 0

        with torch.no_grad():
            for batch_idx, batch in self._iter_batches(loader, f"eval_corpus {dataset_name}", max_eval_batches):
                texts, enc_out, enc_mask = self._encode_batch(batch)

                greedy_tokens = self.generate_greedy(enc_out, cross_key_padding_mask=enc_mask)
                beam_tokens = self._beam_best(enc_out, beam_size)

                batch_summary, rows = self.compare(texts, greedy_tokens, beam_tokens, dataset_name, batch_idx)

                word_edits_g += batch_summary["batch_word_edits_g"]
                word_edits_b += batch_summary["batch_word_edits_b"]
                ref_words += batch_summary["batch_ref_words"]

                char_edits_g += batch_summary["batch_char_edits_g"]
                char_edits_b += batch_summary["batch_char_edits_b"]
                ref_chars += batch_summary["batch_ref_chars"]

                if not log_first_per_batch_only:
                    rows_for_table.extend(rows)
                    continue
                if not rows:  # empty batch: nothing to sample
                    continue

                first_row = rows[0]
                rows_for_table.append(first_row)
                if self.use_wandb:
                    self._log_table_row(
                        self._COMPARE_TABLE_COLUMNS,
                        (
                            first_row["dataset"],
                            first_row["batch_idx"],
                            first_row["sample_idx_in_batch"],
                            first_row["ref"],
                            first_row["pred_beam"],
                            first_row["wer_beam"],
                            first_row["cer_beam"],
                            first_row["pred_greedy"],
                            first_row["wer_greedy"],
                            first_row["cer_greedy"],
                        ),
                        f"eval_table/{dataset_name}_corpus",
                    )

        wer_beam_corpus = self._rate(word_edits_b, ref_words)
        wer_greedy_corpus = self._rate(word_edits_g, ref_words)
        cer_beam_corpus = self._rate(char_edits_b, ref_chars)
        cer_greedy_corpus = self._rate(char_edits_g, ref_chars)

        if self.use_wandb and self.wandb_run is not None:
            self.wandb_run.log({
                f"{dataset_name}/wer_beam": wer_beam_corpus,
                f"{dataset_name}/wer_greedy": wer_greedy_corpus,
                f"{dataset_name}/cer_beam": cer_beam_corpus,
                f"{dataset_name}/cer_greedy": cer_greedy_corpus,
            })

        final = {
            "dataset": dataset_name,
            "wer_beam": wer_beam_corpus,
            "cer_beam": cer_beam_corpus,
            "wer_greedy": wer_greedy_corpus,
            "cer_greedy": cer_greedy_corpus,
        }

        return final, rows_for_table

    def evaluate_with_greedy(self, loader, dataset_name, max_eval_batches=None, log_first_per_batch_only=True, wandb_run=None):
        """Evaluate greedy decoding only.

        Returns:
            ``(final, rows_for_table)`` with ``wer_greedy`` / ``cer_greedy`` and
            the collected per-sample rows (``ref``, ``pred``, ``wer``, ``cer``).
        """
        def predict(enc_out, enc_mask):
            return self.generate_greedy(enc_out, cross_key_padding_mask=enc_mask)

        return self._evaluate_single(
            loader, dataset_name, "greedy", predict,
            max_eval_batches, log_first_per_batch_only, wandb_run,
        )

    def evaluate_with_beam(self, loader, dataset_name, beam_size=5, max_eval_batches=None, log_first_per_batch_only=True, wandb_run=None):
        """Evaluate beam-search decoding only.

        Returns:
            ``(final, rows_for_table)`` with ``wer_beam`` / ``cer_beam`` and the
            collected per-sample rows (``ref``, ``pred``, ``wer``, ``cer``).
        """
        def predict(enc_out, enc_mask):
            # The padding mask is intentionally not forwarded; see _beam_best.
            return self._beam_best(enc_out, beam_size)

        return self._evaluate_single(
            loader, dataset_name, "beam", predict,
            max_eval_batches, log_first_per_batch_only, wandb_run,
        )

    def evaluate_both_and_log(self, loader_clean, loader_other, beam_size=4, max_eval_batches=None, wandb_run=None):
        """Run ``evaluate_dataset`` on the "other" and then the "clean" loader.

        Both evaluations share one W&B run, and the combined metrics are logged
        once more under ``final/<dataset>/<metric>`` keys.

        Returns:
            Dict with ``clean`` / ``other`` result dicts and ``table_clean`` /
            ``table_other`` row lists.
        """
        wandb_run = self._resolve_wandb_run(wandb_run)

        res_other, table_other = self.evaluate_dataset(loader_other, "other", beam_size=beam_size, max_eval_batches=max_eval_batches, wandb_run=wandb_run)
        res_clean, table_clean = self.evaluate_dataset(loader_clean, "clean", beam_size=beam_size, max_eval_batches=max_eval_batches, wandb_run=wandb_run)

        metric_names = ("wer_beam", "cer_beam", "wer_greedy", "cer_greedy")
        log_payload = {}
        for name, result in (("clean", res_clean), ("other", res_other)):
            for metric in metric_names:
                log_payload[f"final/{name}/{metric}"] = result[metric]

        if self.use_wandb and self.wandb_run is not None:
            self.wandb_run.log(log_payload)

        return {
            "clean": res_clean,
            "other": res_other,
            "table_clean": table_clean,
            "table_other": table_other
        }


# NOTE: a stray `aa` expression used to sit here; it is not defined anywhere and
# raised NameError whenever the cell ran, so it has been removed.
