wav2vec 2.0
- `wav2vec2-large` pretrained: https://huggingface.co/facebook/wav2vec2-large-lv60
- `wav2vec2-large` fine-tuned: https://huggingface.co/facebook/wav2vec2-large-960h-lv60

Whisper
- `whisper-large-v2`: https://huggingface.co/openai/whisper-large-v2

To precompute transcripts and other expensive things, run the notebook via ipython from the terminal:

```bash
jupyter nbconvert --to python "ssl_uncertainty.ipynb"
ipython "ssl_uncertainty.py"

```

In [None]:
# Hugging Face Hub identifiers of the checkpoints listed at the top of the notebook.
WAV2VEC_LARGE_PRETRAINED = "facebook/wav2vec2-large-lv60"  # pretrained wav2vec 2.0 large
WAV2VEC_LARGE_FINETUNED = "facebook/wav2vec2-large-960h-lv60"  # fine-tuned wav2vec 2.0 large
WHISPER_LARGE_V2 = "openai/whisper-large-v2"

In [None]:
# Machine-specific absolute paths; adjust them when running on another host.
HUGGINFACE_HOME = "/m2/research/huggingface"  # value assigned to the HF_HOME environment variable
TEMP_SAVE_DIR = "/m2/research/jdh/thesis/"  # root directory for cached transcripts, logits and features

In [None]:
import os

# Set HF_HOME before the cell that imports `datasets` / `transformers` runs —
# presumably so that those libraries pick the location up at import time (TODO confirm).
os.environ["HF_HOME"] = HUGGINFACE_HOME
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import functools

from typing import List, Dict
from copy import deepcopy
from collections import defaultdict

import torch
import numba
import numpy as np
import rich
import matplotlib.pyplot as plt
import scipy
import pandas as pd
import jiwer

from datasets import load_dataset, load_from_disk
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Config
from torch.utils.data import DataLoader
from jiwer import wer, cer
from tqdm import tqdm
from IPython.display import HTML, display

In [None]:
# Report whether CUDA is usable and how many devices are visible.
print(torch.cuda.is_available(), torch.cuda.device_count())

# The first CUDA device is hard-coded; there is no CPU fallback in this notebook.
device = torch.device('cuda:0')

# Make `device` the default device for tensors created without an explicit `device=`.
torch.set_default_device(device)

## Dataset 

In [None]:
def collate_fn(batch, feature_extractor, tokenizer=None):
    """Collate a list of dataset examples into a single padded batch.

    Args:
        batch: Sequence of example dicts providing the keys "file", "audio"
            (with an "array" entry), "text", "speaker_id", "chapter_id" and
            "id"; a "length" key is optional.
        feature_extractor: Callable invoked on the list of waveforms; its result
            must expose `input_values` and `attention_mask` as NumPy arrays.
        tokenizer: Optional callable invoked on the list of texts; its result
            must expose `input_ids` as a NumPy array.

    Returns:
        Dict with tensors under "input_values" and "attention_mask", the
        per-example metadata lists, "lengths" (CPU tensor) when every example
        carries a "length", and "labels" when a tokenizer is given.
    """
    def gather(key):
        return [example[key] for example in batch]

    # Read every column up front, before any model-side processing happens.
    file_names = gather("file")
    waveforms = [example["audio"]["array"] for example in batch]
    transcriptions = gather("text")
    speakers = gather("speaker_id")
    chapters = gather("chapter_id")
    utterance_ids = gather("id")

    extracted = feature_extractor(
        waveforms,
        sampling_rate=16_000,
        padding="longest",
        return_tensors="np",
        return_attention_mask=True,
    )

    collated = {
        "input_values": torch.from_numpy(extracted.input_values),
        "attention_mask": torch.from_numpy(extracted.attention_mask),
        "speaker_ids": speakers,
        "chapter_ids": chapters,
        "ids": utterance_ids,
        "files": file_names,
        "texts": transcriptions,
    }

    # Lengths are only included when no example is missing one.
    if not any("length" not in example for example in batch):
        collated["lengths"] = torch.as_tensor(gather("length"), device="cpu")

    if tokenizer is not None:
        encoded = tokenizer(transcriptions, return_tensors="np", padding="longest", return_attention_mask=False)
        collated["labels"] = torch.from_numpy(encoded.input_ids)

    return collated

In [None]:
@torch.autocast(device_type="cuda")
@torch.inference_mode()
def transcribe_batch(batch, model, tokenizer):
    """Greedily transcribe one collated batch with a CTC model.

    Args:
        batch: Collated batch providing "input_values", "attention_mask",
            "lengths" (number of audio samples per example) and "ids".
        model: Model whose output exposes `.logits`.
        tokenizer: Tokenizer providing `batch_decode` for the arg-max token ids.

    Returns:
        Tuple `(ids, transcripts, logits)`, where `logits` is a list with one
        CPU tensor per example, truncated along the time axis.
    """
    input_values = batch["input_values"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    logits = model(input_values, attention_mask=attention_mask).logits

    pred_ids = torch.argmax(logits, dim=-1)
    transcripts = tokenizer.batch_decode(pred_ids)

    # Remove padding from the logits. 320 is the model hop length.
    # NOTE(review): samples // 320 may differ slightly from the exact number of
    # output frames of the convolutional front-end — verify if batch_size > 1.
    frame_counts = (batch["lengths"] // 320).tolist()
    logits = [
        example_logits[:num_frames, :]
        for example_logits, num_frames in zip(logits.cpu().unbind(0), frame_counts)
    ]
    return batch["ids"], transcripts, logits


def transcribe_dataset(dataloader, model, tokenizer):
    """Transcribe every batch of `dataloader` and concatenate the results.

    Dataset-level counterpart of `transcribe_batch`, in the same shape as
    `extract_features_dataset`.

    Returns:
        Tuple `(ids, transcripts, logits)` of flat lists holding one entry per
        example, in the order the dataloader yields them.
    """
    ids, transcripts, logits = [], [], []
    for batch in tqdm(dataloader, desc="Transcribing"):
        batch_ids, batch_transcripts, batch_logits = transcribe_batch(batch, model, tokenizer)
        ids.extend(batch_ids)
        transcripts.extend(batch_transcripts)
        logits.extend(batch_logits)
    return ids, transcripts, logits

# Uncertainty in transcripts

In [None]:
# Load the "test" split of the "clean" configuration of LibriSpeech ASR.
librispeech_test_clean = load_dataset("librispeech_asr", "clean", split="test")

In [None]:
def compute_length(example):
    """Record the number of audio samples of `example` under the "length" key.

    The example is updated in place and returned so the function can be used
    as a mapping function over a dataset.
    """
    num_samples = len(example["audio"]["array"])
    example["length"] = num_samples
    return example

# Add a "length" column (number of audio samples) to every example.
librispeech_test_clean = librispeech_test_clean.map(compute_length)

In [None]:
# Load the fine-tuned CTC model, its feature extractor and its tokenizer from the same checkpoint.
w2v2_model = Wav2Vec2ForCTC.from_pretrained(WAV2VEC_LARGE_FINETUNED).to(device)
w2v2_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(WAV2VEC_LARGE_FINETUNED)
w2v2_tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(WAV2VEC_LARGE_FINETUNED)
# processor = Wav2Vec2Processor.from_pretrained(WAV2VEC_LARGE_FINETUNED)

In [None]:
# Bind the fine-tuned model's tokenizer and feature extractor into the collate function.
collate_fn_w2v2 = functools.partial(collate_fn, tokenizer=w2v2_tokenizer, feature_extractor=w2v2_feature_extractor)
# One example per batch, loaded in the main process.
dataloader = DataLoader(librispeech_test_clean, batch_size=1, collate_fn=collate_fn_w2v2, num_workers=0)
iterator = iter(dataloader)
next(iterator)  # fetch the first batch to inspect the collated output

In [None]:
# Peek at the first two examples of the dataset.
librispeech_test_clean[:2]

In [None]:
# Transcribe the dataset and save the results (load if exists already)

# Directory holding the dataset with its "transcript" column plus the logits file.
LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR = os.path.join(TEMP_SAVE_DIR, "librispeech_test_clean_with_w2v2_transcripts")

try:
    # Cache hit: reuse the saved dataset and the saved per-example logits.
    librispeech_test_clean = load_from_disk(LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR)
    assert "transcript" in librispeech_test_clean.column_names
    logits = torch.load(os.path.join(LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR, "logits.pt"), map_location="cpu")
except (AssertionError, FileNotFoundError):
    # Cache miss: transcribe, attach the transcripts as a column and persist both artefacts.
    # NOTE(review): if the dataset loads but "logits.pt" is missing, `librispeech_test_clean`
    # has already been rebound to the loaded dataset when this branch runs — verify that
    # adding the "transcript" column still works in that situation.
    ids, transcripts, logits = transcribe_dataset(dataloader, w2v2_model, w2v2_tokenizer)
    librispeech_test_clean = librispeech_test_clean.add_column("transcript", transcripts)
    librispeech_test_clean.save_to_disk(LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR)
    torch.save(logits, os.path.join(LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR, "logits.pt"), )

In [None]:
def compute_mc_transcripts(dataloader, model, tokenizer, num_samples, cache_dir=None):
    """Compute Monte Carlo transcripts for a dataset `num_samples` times.

    Each pass is cached on disk as `mc_ids_XXX.pt`, `mc_transcript_XXX.pt` and
    `mc_logits_XXX.pt`; a pass whose files all exist is loaded instead of
    recomputed.

    Args:
        dataloader: Dataloader passed on to `transcribe_dataset`.
        model: Model to sample from; it is put in train mode while sampling.
        tokenizer: Tokenizer passed on to `transcribe_dataset`.
        num_samples: Number of stochastic passes over the dataset.
        cache_dir: Directory for the per-pass cache files. Defaults to
            `LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR`.

    Returns:
        Tuple `(mc_ids, mc_transcripts, mc_logits)` of lists with one entry per pass.
    """
    if cache_dir is None:
        cache_dir = LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR

    was_training = model.training
    model.train()  # set to train mode to enable dropout
    # NOTE(review): train mode may enable other stochastic behaviour besides
    # dropout, depending on the model configuration — confirm this is intended.

    mc_ids = []
    mc_transcripts = []
    mc_logits = []

    try:
        bar = tqdm(range(num_samples), desc="Getting MC Transcripts")
        for i in bar:
            bar.set_description(f"Getting MC Transcripts {i:03d}")

            ids_path = os.path.join(cache_dir, f"mc_ids_{i:03d}.pt")
            transcripts_path = os.path.join(cache_dir, f"mc_transcript_{i:03d}.pt")
            logits_path = os.path.join(cache_dir, f"mc_logits_{i:03d}.pt")

            try:
                ids = torch.load(ids_path, map_location="cpu")
                transcripts = torch.load(transcripts_path, map_location="cpu")
                logits = torch.load(logits_path, map_location="cpu")
            except FileNotFoundError as exc:
                # Any missing file invalidates the whole pass: recompute and rewrite all three.
                print(exc)
                ids, transcripts, logits = transcribe_dataset(dataloader, model, tokenizer)
                torch.save(ids, ids_path)
                torch.save(transcripts, transcripts_path)
                torch.save(logits, logits_path)

            mc_ids.append(ids)
            mc_transcripts.append(transcripts)
            mc_logits.append(logits)
    finally:
        # Do not leak train mode to later users of the model.
        model.train(was_training)

    return mc_ids, mc_transcripts, mc_logits

In [None]:
# Compute and add Monte Carlo Dropout simulations
NUM_SAMPLES = 256  # number of stochastic passes over the whole dataset

# Each returned list holds one entry per pass.
mc_ids, mc_transcripts, mc_logits = compute_mc_transcripts(dataloader, w2v2_model, w2v2_tokenizer, num_samples=NUM_SAMPLES)


In [None]:
# raise Exception("Stop here")

In [None]:
@numba.jit(nopython=True)
def levenstein_distance(reference: List[str], prediction: List[str]) -> np.ndarray:
    """Return the full edit-distance table between two token sequences.

    Entry `[r, c]` holds the cost of turning the first `r` reference tokens into
    the first `c` prediction tokens, with unit cost for deletions, insertions
    and substitutions.
    """
    n_ref = len(reference)
    n_hyp = len(prediction)

    table = np.zeros((n_ref + 1, n_hyp + 1))

    # First column: delete every reference token seen so far.
    table[:, 0] = np.arange(n_ref + 1)
    # First row: insert every prediction token seen so far.
    table[0, :] = np.arange(n_hyp + 1)

    for row in range(1, n_ref + 1):
        ref_token = reference[row - 1]
        for col in range(1, n_hyp + 1):
            deletion = table[row - 1, col] + 1
            insertion = table[row, col - 1] + 1
            if ref_token == prediction[col - 1]:
                substitution = table[row - 1, col - 1] + 0
            else:
                substitution = table[row - 1, col - 1] + 1
            table[row, col] = min(deletion, insertion, substitution)

    return table


def align_transcripts_charlevel(reference: str, prediction: str) -> "List[CharAlignmentElement]":
    """Align two strings character by character.

    The return annotation is a string because `CharAlignmentElement` is defined
    further down in the same cell; an unquoted annotation is evaluated when this
    `def` executes and fails on a fresh kernel.

    Args:
        reference: Reference string.
        prediction: Hypothesis string aligned against the reference.

    Returns:
        List of `CharAlignmentElement` in left-to-right order, typed "MATCHED",
        "MISSPELLED" (substitution), "OMITTED" (deletion) or "EXTRA" (insertion).
    """
    reference_chars = list(reference)
    prediction_chars = list(prediction)

    # Empty inputs are resolved directly, without calling the jitted kernel on an empty list.
    if not prediction_chars:
        return [CharAlignmentElement("OMITTED", "", char) for char in reference_chars]
    if not reference_chars:
        return [CharAlignmentElement("EXTRA", char, "") for char in prediction_chars]

    alignment_matrix = levenstein_distance(reference_chars, prediction_chars)

    # Backtrace from the bottom-right corner. Elements are collected back-to-front
    # and reversed once at the end, instead of inserting at index 0 every step.
    i, j = len(reference_chars), len(prediction_chars)
    reversed_alignment = []

    while i > 0 or j > 0:
        if i > 0 and alignment_matrix[i][j] == alignment_matrix[i - 1][j] + 1:  # Deletion
            reversed_alignment.append(CharAlignmentElement("OMITTED", "", reference_chars[i - 1]))
            i -= 1
        elif j > 0 and alignment_matrix[i][j] == alignment_matrix[i][j - 1] + 1:  # Insertion
            reversed_alignment.append(CharAlignmentElement("EXTRA", prediction_chars[j - 1], ""))
            j -= 1
        else:
            # Diagonal step: substitution when the characters differ, match otherwise.
            kind = "MISSPELLED" if reference_chars[i - 1] != prediction_chars[j - 1] else "MATCHED"
            reversed_alignment.append(CharAlignmentElement(kind, prediction_chars[j - 1], reference_chars[i - 1]))
            i -= 1
            j -= 1

    reversed_alignment.reverse()
    return reversed_alignment


# Symbols used when rendering alignments.
MISSING_INDICATOR = "□"  # placeholder shown where one side has no character
WHITESPACE = " "
WHITESPACE_TOKEN = "·"  # visible stand-in for a whitespace character
PUNCTUATION = list(".,!?:;'()[]{}-_/\\|@#$%^&*~`+=<>")


class CharAlignmentElement:
    """One aligned character pair between a reference and a hypothesis string.

    Whitespace characters are stored as `WHITESPACE_TOKEN` so that they stay
    visible when printed.
    """

    def __init__(self, type_, hyp_char, ref_char):
        self.type = type_

        if hyp_char == WHITESPACE:
            hyp_char = WHITESPACE_TOKEN
        self.hyp_char = hyp_char

        if ref_char == WHITESPACE:
            ref_char = WHITESPACE_TOKEN
        self.ref_char = ref_char

    def __repr__(self):
        return f"CharAlignmentElement(type_='{self.type}', hyp_char='{self.hyp_char}', ref_char='{self.ref_char}')"


class WordAlignmentElement:
    """A word-level group of character alignment elements.

    Stores the raw reference and hypothesis words, a word-level CER, an error
    type ("MATCHED", "OMITTED", "EXTRA" or "MISSPELLED") and aligned renderings
    of both words, plain and with rich colour markup.
    """

    def __init__(self, char_alignment: List[CharAlignmentElement]):
        self.char_alignment = char_alignment

        self.hyp_chars = [element.hyp_char for element in char_alignment]
        self.ref_chars = [element.ref_char for element in char_alignment]

        # Whitespace tokens at either end belong to the word boundary, not to the word.
        self.hyp_word = "".join(self.hyp_chars).strip(WHITESPACE_TOKEN)
        self.ref_word = "".join(self.ref_chars).strip(WHITESPACE_TOKEN)

        hyp_empty = self.hyp_word == ""
        ref_empty = self.ref_word == ""

        # Word-level character error rate, with fixed values when a side is empty.
        if not hyp_empty and not ref_empty:
            self.cer = cer(
                self.ref_word,
                self.hyp_word,
                reference_transform=jiwer.transforms.ReduceToListOfListOfChars(),
                hypothesis_transform=jiwer.transforms.ReduceToListOfListOfChars()
            )
        elif hyp_empty and ref_empty:
            self.cer = 0.0
        elif ref_empty:
            self.cer = np.inf
        else:
            self.cer = 1.0

        # Word-level error type.
        if self.hyp_word == self.ref_word:
            self.type = "MATCHED"
        elif hyp_empty:
            self.type = "OMITTED"
        elif ref_empty:
            self.type = "EXTRA"
        else:
            self.type = "MISSPELLED"

        # Colour per character-level type; None means "no markup".
        # Elements of any other type contribute nothing to the renderings.
        colour_by_type = {
            "MATCHED": None,
            "MISSPELLED": "orange1",
            "OMITTED": "red",
            "EXTRA": "green",
        }

        ref_plain, hyp_plain = [], []
        ref_coloured, hyp_coloured = [], []
        for element in self.char_alignment:
            if element.type not in colour_by_type:
                continue

            # The side that has no character is rendered as the missing indicator.
            ref_piece = f"{MISSING_INDICATOR}" if element.type == "EXTRA" else element.ref_char
            hyp_piece = f"{MISSING_INDICATOR}" if element.type == "OMITTED" else element.hyp_char
            ref_plain.append(ref_piece)
            hyp_plain.append(hyp_piece)

            colour = colour_by_type[element.type]
            if colour is None:
                ref_coloured.append(ref_piece)
                hyp_coloured.append(hyp_piece)
            else:
                ref_coloured.append(f"[{colour}]{ref_piece}[/{colour}]")
                hyp_coloured.append(f"[{colour}]{hyp_piece}[/{colour}]")

        self.hyp_word_aligned = "".join(hyp_plain).strip(WHITESPACE_TOKEN)
        self.ref_word_aligned = "".join(ref_plain).strip(WHITESPACE_TOKEN)
        self.hyp_word_aligned_colored = "".join(hyp_coloured).strip(WHITESPACE_TOKEN)
        self.ref_word_aligned_colored = "".join(ref_coloured).strip(WHITESPACE_TOKEN)

    def print_raw(self):
        """Print the raw reference word, then the raw hypothesis word."""
        for word in (self.ref_word, self.hyp_word):
            rich.print(word)

    def print_aligned(self):
        """Print the aligned reference word, then the aligned hypothesis word."""
        for word in (self.ref_word_aligned, self.hyp_word_aligned):
            rich.print(word)

    def print_aligned_colored(self):
        """Print the colour-marked aligned reference word, then the hypothesis word."""
        for word in (self.ref_word_aligned_colored, self.hyp_word_aligned_colored):
            rich.print(word)

    def __repr__(self):
        pieces = [
            "WordAlignmentElement(\n",
            f"    ref_word='{self.ref_word_aligned}',\n",
            f"    hyp_word='{self.hyp_word_aligned}',\n",
            f"    type_='{self.type}',\n",
            f"    cer={self.cer:.2f}\n",
            ")",
        ]
        return "".join(pieces)


class Alignment:
    """Class to store an alignment between a reference and a Monte Carlo transcript.

    Attributes set by the constructor:
        reference, prediction: inputs with whitespace replaced by `WHITESPACE_TOKEN`.
        char_alignments: list of `CharAlignmentElement`.
        word_alignments: list of `WordAlignmentElement` (always at least one).
        wer, cer: error rates from jiwer, or fixed values when one side is empty.
        reference_aligned, prediction_aligned (+ `_colored` variants): renderings.
        num_spaces, num_space_substitutions, num_space_insertions,
        num_space_deletions, space_error_rate: whitespace statistics.
        corrected_wer, missing_space_error, extra_space_error: word-level counts.
    """
    def __init__(self, reference: str, prediction: str):
        self.reference = reference.replace(WHITESPACE, WHITESPACE_TOKEN)
        self.prediction = prediction.replace(WHITESPACE, WHITESPACE_TOKEN)

        # Empty inputs are handled without running the aligner or jiwer.
        if len(prediction) == 0:
            self.char_alignments = [CharAlignmentElement("OMITTED", "", char) for char in reference]
            self.wer = 1.0
            self.cer = 1.0
        elif len(reference) == 0:
            self.char_alignments = [CharAlignmentElement("EXTRA", char, "") for char in prediction]
            self.wer = np.inf
            self.cer = np.inf
        else:
            self.char_alignments = align_transcripts_charlevel(reference, prediction)
            self.wer = wer(reference, prediction)
            self.cer = cer(reference, prediction)

        self.word_alignments = []

        # Cut the character alignment into word chunks. A chunk is closed at position j
        # once BOTH sides have produced a whitespace token since the previous cut; the
        # element at j starts the next chunk.
        ref_ended = False
        hyp_ended = False
        i = 0
        for j, element in enumerate(self.char_alignments):
            ref_ended = ref_ended or element.ref_char == WHITESPACE_TOKEN
            hyp_ended = hyp_ended or element.hyp_char == WHITESPACE_TOKEN

            if ref_ended and hyp_ended:
                self.word_alignments.append(WordAlignmentElement(self.char_alignments[i:j]))
                i = j
                ref_ended = False
                hyp_ended = False

        self.word_alignments.append(WordAlignmentElement(self.char_alignments[i:]))

        self.reference_aligned = WHITESPACE_TOKEN.join([e.ref_word_aligned for e in self.word_alignments])
        self.prediction_aligned = WHITESPACE_TOKEN.join([e.hyp_word_aligned for e in self.word_alignments])

        self.reference_aligned_colored = WHITESPACE_TOKEN.join([e.ref_word_aligned_colored for e in self.word_alignments])
        self.prediction_aligned_colored = WHITESPACE_TOKEN.join([e.hyp_word_aligned_colored for e in self.word_alignments])

        # Compute corrected WER and missing/extra space error
        self.num_spaces = sum(1 for e in self.char_alignments if e.ref_char == WHITESPACE_TOKEN)
        self.num_space_substitutions = sum(1 for e in self.char_alignments if e.type == "MISSPELLED" and e.hyp_char == WHITESPACE_TOKEN)
        self.num_space_insertions = sum(1 for e in self.char_alignments if e.type == "EXTRA" and e.hyp_char == WHITESPACE_TOKEN)
        self.num_space_deletions = sum(1 for e in self.char_alignments if e.type == "OMITTED" and e.ref_char == WHITESPACE_TOKEN)

        num_space_errors = self.num_space_substitutions + self.num_space_insertions + self.num_space_deletions
        if self.num_spaces > 0:
            self.space_error_rate = num_space_errors / self.num_spaces
        elif num_space_errors > 0:
            # Errors against a reference without spaces: the rate is unbounded.
            self.space_error_rate = np.inf
        else:
            # No spaces and no space errors: nothing went wrong.
            self.space_error_rate = 0.0

        self.corrected_wer = 0.0
        self.missing_space_error = 0.0
        self.extra_space_error = 0.0
        for word_alignment in self.word_alignments:
            if word_alignment.type == "MISSPELLED":
                self.corrected_wer += 1.0
            elif word_alignment.type == "OMITTED":
                self.missing_space_error += 1.0
            elif word_alignment.type == "EXTRA":
                self.extra_space_error += 1.0

        # Only the misspelling count is normalised by the number of word chunks.
        self.corrected_wer /= len(self.word_alignments)

    def print_raw(self):
        """Print the reference and the prediction, one per line."""
        rich.print(self.reference, self.prediction, sep="\n", end="\n")

    def print_aligned(self):
        """Print the aligned reference and prediction, one per line."""
        rich.print(self.reference_aligned, self.prediction_aligned, sep="\n", end="\n")

    def print_aligned_colored(self):
        """Print the colour-marked aligned reference and prediction, one per line."""
        rich.print(self.reference_aligned_colored, self.prediction_aligned_colored, sep="\n", end="\n")

    def report_errors(self):
        """Print and return a short report of WER, CER, corrected WER and space error rate."""
        s = ""
        s = s + f"WER  : {self.wer:.2f}\n"
        s = s + f"CER  : {self.cer:.2f}\n"
        s = s + f"CWER : {self.corrected_wer:.2f}\n"
        s = s + f"SER  : {self.space_error_rate:.2f}\n"
        rich.print(s)
        return s

    def __repr__(self):
        s = "Alignment(\n"
        s += f"    ref='{self.reference_aligned}',\n"
        s += f"    hyp='{self.prediction_aligned}'\n"
        s += f"    wer={self.wer:.2f},\n"
        s += f"    cer={self.cer:.2f}\n"
        s += ")"
        return s

In [None]:
# Toy example: a merged word pair ("hej digder" vs "hejdig der") plus spelling differences.
reference = "hej digder karl smart"
prediction = "hejdig der carl smat"

alignment = Alignment(reference, prediction)
alignment.print_raw()
alignment.print_aligned_colored()
alignment.report_errors();

# Show every word-level chunk of the alignment.
for word_alignment in alignment.word_alignments:
    print(word_alignment)

# Whitespace statistics of the alignment.
alignment.num_spaces, alignment.num_space_substitutions, alignment.num_space_insertions, alignment.num_space_deletions

In [None]:
# Reference word of the second word-level chunk.
alignment.word_alignments[1].ref_word

In [None]:
# Toy example: a split word ("hejsa" vs "hej sa") and a merged word pair ("du der" vs "duder").
reference = "hejsa du der"
prediction = "hej sa duder"

alignment = Alignment(reference, prediction)
alignment.print_raw()
alignment.print_aligned_colored()
alignment.report_errors();

# Show every word-level chunk of the alignment.
for word_alignment in alignment.word_alignments:
    print(word_alignment)

# Whitespace statistics of the alignment.
alignment.num_spaces, alignment.num_space_substitutions, alignment.num_space_insertions, alignment.num_space_deletions

In [None]:
# Compare gold text, standard transcript and the first Monte Carlo transcript of one example.
example_id = 20

reference = librispeech_test_clean[example_id]["text"]
prediction = librispeech_test_clean[example_id]["transcript"]
mc_prediction = mc_transcripts[0][example_id]  # pass 0, same example

print("Gold VS Prediction")
alignment = Alignment(reference, prediction)
alignment.print_aligned_colored()

# Here the standard transcript plays the role of the reference.
print("Prediction VS MC Prediction")
alignment = Alignment(prediction, mc_prediction)
alignment.print_aligned_colored()

In [None]:
# Inspect the last alignment together with a few of its word-level chunks.
alignment, alignment.word_alignments[0], alignment.word_alignments[-4], alignment.word_alignments[-7]

In [None]:
# NOTE(review): `CharAlignment` and its `visualize` method are not defined in the visible
# part of this notebook (the classes defined here are `Alignment`, `WordAlignmentElement`
# and `CharAlignmentElement`) — confirm whether this cell is still meant to run.
example_id = 20

reference = librispeech_test_clean[example_id]["text"]
prediction = librispeech_test_clean[example_id]["transcript"]
mc_prediction = mc_transcripts[0][example_id]

print("Gold VS Prediction")
alignment = CharAlignment(reference, prediction)
alignment.visualize(include_cer=True)

print("Prediction VS MC Prediction")
alignment = CharAlignment(prediction, mc_prediction)
alignment.visualize(include_cer=True)

In [None]:
# NOTE(review): this cell repeats the `CharAlignment` comparison for the same example;
# `CharAlignment` and `visualize` are not defined in the visible part of this notebook —
# confirm whether this cell is still meant to run.
example_id = 20

reference = librispeech_test_clean[example_id]["text"]
prediction = librispeech_test_clean[example_id]["transcript"]
mc_prediction = mc_transcripts[0][example_id]

print("Gold VS Prediction")
alignment = CharAlignment(reference, prediction)
alignment.visualize(include_cer=True)

print("Prediction VS MC Prediction")
alignment = CharAlignment(prediction, mc_prediction)
alignment.visualize(include_cer=True)

In [None]:
# NOTE(review): `WordAlignment`, `CharAlignment` and `visualize` are not defined in the
# visible part of this notebook — confirm whether this cell is still meant to run.
example_id = 20

reference = librispeech_test_clean[example_id]["text"]
prediction = librispeech_test_clean[example_id]["transcript"]
mc_prediction = mc_transcripts[0][example_id]

# Word-level view followed by character-level view of the same pair.
print("Gold VS Prediction")
alignment = WordAlignment(reference, prediction)
alignment.visualize(include_cer=True)
alignment = CharAlignment(reference, prediction)
alignment.visualize(include_cer=True)

print("Prediction VS MC Prediction")
alignment = WordAlignment(prediction, mc_prediction)
alignment.visualize(include_cer=True)
alignment = CharAlignment(prediction, mc_prediction)
alignment.visualize(include_cer=True)

In [None]:
# Compute the alignment to the prediction for all Monte Carlo transcripts
alignments_pred_to_gold = []
alignments_mc_to_pred = []

# Read the two text columns once. Indexing the dataset row by row materialises the
# whole example on every access, which is not needed here.
gold_texts = librispeech_test_clean["text"]
standard_transcripts = librispeech_test_clean["transcript"]

# mc_transcripts is indexed [pass][example]; transpose it to [example][pass].
mc_transcripts_transposed = list(zip(*mc_transcripts))
for i in tqdm(range(len(librispeech_test_clean)), desc="Computing alignments"):
    reference = gold_texts[i]
    prediction = standard_transcripts[i]

    alignments_pred_to_gold.append(Alignment(reference, prediction))
    alignments_mc_to_pred.append([Alignment(prediction, prediction_mc) for prediction_mc in mc_transcripts_transposed[i]])

In [None]:
# Obtain the list of CERs over Monte Carlo transcripts for each word in the reference transcript

all_errors = [defaultdict(list) for _ in range(len(librispeech_test_clean))]
misspellings = [defaultdict(list) for _ in range(len(librispeech_test_clean))]
fraction_misspellings = [defaultdict(int) for _ in range(len(librispeech_test_clean))]
for i in tqdm(range(len(librispeech_test_clean)), desc="Counting misspellings and omissions"):

    # Get the alignments
    alignments_mc_to_pred_i = alignments_mc_to_pred[i]  # Get the alignments for all Monte Carlo transcripts
    for alignment_mc_to_pred in alignments_mc_to_pred_i:

        # Keep track of how many times each reference word has been seen in a single alignment.
        # Some words might be used multiple times and we must discern between them.
        word_count = defaultdict(int)

        # Iterate over the word-level chunks and save the CER for each word.
        # `Alignment` exposes them as `word_alignments`.
        for element in alignment_mc_to_pred.word_alignments:
            word_count[element.ref_word] += 1

            # Repeated words get one trailing underscore per earlier occurrence.
            k = element.ref_word + "_" * (word_count[element.ref_word] - 1)
            all_errors[i][k].append(element.cer)

            # "EXTRA" chunks have no reference word, so they are excluded here.
            if element.type != "EXTRA":
                misspellings[i][k].append(element.cer)
                fraction_misspellings[i][k] += 1 / NUM_SAMPLES

print("Num CERs in gold: ", len([e.cer for a in alignments_pred_to_gold for e in a.word_alignments if e.type != "OMITTED"]))
print("Num CERs in pred: ", len([cer for error_cer in misspellings for cer in error_cer.values()]))

In [None]:
# Per example: median and mean over the collected CER values of every word key.
median_misspellings = [{k: np.median(v) for k, v in misspelling.items()} for misspelling in misspellings]
mean_misspellings = [{k: np.mean(v) for k, v in misspelling.items()} for misspelling in misspellings]

In [None]:
# NOTE(review): `Alignment` as defined in this notebook has no `visualize` method —
# confirm whether this cell is still meant to run.
example_id = 59
alignments_pred_to_gold[example_id].visualize(include_cer=True, other_numeric_scores={"Median CER": median_misspellings[example_id].values()})

In [None]:
# Word-level CERs of the standard transcript against the gold text. "OMITTED" chunks have
# no predicted word and are skipped. `Alignment` exposes the chunks as `word_alignments`.
cers_from_prediction = np.array([e.cer for a in alignments_pred_to_gold for e in a.word_alignments if e.type != "OMITTED"])
# Median word-level CERs of the Monte Carlo transcripts against the standard transcript.
cers_from_montecarlo = np.array([cer for misspelling_cer in median_misspellings for cer in misspelling_cer.values()])
# cers_from_montecarlo = np.array([cer for misspelling_cer in mean_misspellings for cer in misspelling_cer.values()])
cers_from_prediction.shape, cers_from_montecarlo.shape

In [None]:
# Scatter plot of per-word CER (standard transcript vs gold) against the median MC CER.
fig, ax = plt.subplots(figsize=(6.4, 4.8))
ax.scatter(cers_from_prediction, cers_from_montecarlo, alpha=0.1)

# Annotate the plot with both correlation coefficients (first element of each result).
# NOTE(review): word-level CER can be np.inf when the reference word is empty; such values
# would presumably make the correlations NaN — verify the inputs are finite.
ax.text(0.76, 0.95, f"Pearson: {scipy.stats.pearsonr(cers_from_prediction, cers_from_montecarlo)[0]:.3f}", transform=ax.transAxes)
ax.text(0.73, 0.90, f"Spearman: {scipy.stats.spearmanr(cers_from_prediction, cers_from_montecarlo)[0]:.3f}", transform=ax.transAxes)

ax.set_xlabel("CER of standard transcript cf. target")
ax.set_ylabel("Median CER of MC transcripts cf. standard transcript")
# Both axes are clipped to [0, 1]; points outside this range are not shown.
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
fig.savefig("mc_transcript_cer_scatter.pdf", bbox_inches="tight")

In [None]:
# Above but heatmap
# NOTE(review): `LogNorm` and `sns` are imported but not used in this cell.
from matplotlib.colors import LogNorm
import seaborn as sns

# Fit a 2D Gaussian kernel density estimate to the (prediction CER, MC CER) pairs.
data = np.stack([cers_from_prediction, cers_from_montecarlo], axis=0)
kde = scipy.stats.gaussian_kde(data, bw_method=0.3)

xmin, xmax = 0, 1
ymin, ymax = 0, 1

# log_norm = LogNorm(vmin=data.min().min(), vmax=data.max().max())
# Evaluate the density on a 100 x 100 grid over the unit square.
xx, yy = np.meshgrid(np.linspace(xmin, xmax, 100), np.linspace(ymin, ymax, 100))
positions = np.vstack([xx.ravel(), yy.ravel()])
# values = np.vstack([x, y])

# Log-density on the grid.
f = np.log(np.reshape(kde(positions).T, xx.shape))

fig = plt.figure()
ax = fig.gca()

ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)

# Filled contours plus thin black contour lines, 15 levels each.
# NOTE(review): `label=` is presumably not a supported contourf keyword — verify.
ax.contourf(xx, yy, f, 15, cmap='Blues', label="Log-density")
cset = ax.contour(xx, yy, f, 15, colors='k', linewidths=0.5)

# Overlay the raw data points.
ax.scatter(cers_from_prediction, cers_from_montecarlo, alpha=0.1, marker=".", s=1, color="black", label="Data")

ax.clabel(cset, inline=1, fontsize=10)
ax.set_xlabel("CER from Prediction")
ax.set_ylabel("CER from Monte Carlo Dropout")
# ax.legend()

In [None]:
# Is there a correlation (per word) between the fraction of MC transcripts that are wrong compared to the standard transcript and whether the standard transcript was correct compared to the target?
# Is there a correlation (per word) between the fraction of MC transcripts that are wrong compared to the standard transcript and the CER of the standard transcript compared to the target?


# Uncertainty in representations

In [None]:
# Reload the original dataset; this rebinds the name used in the transcript section above.
librispeech_test_clean = load_dataset("librispeech_asr", "clean", split="test")

In [None]:
def compute_length(example):
    """Attach the audio sample count of `example` as its "length" entry and return it."""
    waveform = example["audio"]["array"]
    example["length"] = len(waveform)
    return example

# Add a "length" column (number of audio samples) to every example.
librispeech_test_clean = librispeech_test_clean.map(compute_length)

In [None]:
# processor = Wav2Vec2Processor.from_pretrained(WAV2VEC_LARGE_PRETRAINED)
# Load the configuration, feature extractor and base model of the pretrained checkpoint.
wav2vec_config = Wav2Vec2Config.from_pretrained(WAV2VEC_LARGE_PRETRAINED)
w2v2_feature_extractor_pretrained = Wav2Vec2FeatureExtractor.from_pretrained(WAV2VEC_LARGE_PRETRAINED)
wav2vec_large_pretrained = Wav2Vec2Model.from_pretrained(WAV2VEC_LARGE_PRETRAINED).to(device)

In [None]:
# Display the configuration of the pretrained checkpoint.
wav2vec_config

In [None]:
# No tokenizer is bound here, so the collated batches carry no "labels".
collate_fn_w2v2 = functools.partial(collate_fn, feature_extractor=w2v2_feature_extractor_pretrained)
# One example per batch, loaded by four worker processes.
dataloader = DataLoader(librispeech_test_clean, batch_size=1, collate_fn=collate_fn_w2v2, num_workers=4)
iterator = iter(dataloader)
batch = next(iterator)  # first batch, reused by the following cells
batch

In [None]:
# Run the pretrained model on the single batch fetched above.
input_values = batch["input_values"].to("cuda")
attention_mask = batch["attention_mask"].to("cuda")

# Gradients are not needed; also request the hidden states in the output.
with torch.no_grad():
    features = wav2vec_large_pretrained(input_values, attention_mask=attention_mask, output_hidden_states=True)

In [None]:
# List the fields present in the model output.
features.keys()

In [None]:
# Shape of the final hidden state.
features.last_hidden_state.shape

In [None]:
# Shape of the `extract_features` output field.
features.extract_features.shape

In [None]:
# Number of returned hidden states and the shapes of the first and last one.
len(features.hidden_states), features.hidden_states[0].shape, features.hidden_states[-1].shape

In [None]:
@torch.autocast(device_type="cuda")
@torch.inference_mode()
def extract_features_batch(batch, model):
    """Extract two intermediate hidden states for one collated batch.

    Args:
        batch: Collated batch providing "input_values", "attention_mask",
            "lengths" and "ids".
        model: Model called with `output_hidden_states=True`.

    Returns:
        Tuple `(ids, features)`, where `features` maps 15 and 18 to lists holding
        one CPU tensor per example, truncated along the time axis.
    """
    outputs = model(
        batch["input_values"].to(device),
        attention_mask=batch["attention_mask"].to(device),
        output_hidden_states=True,
    )

    # 320 is the model hop length.
    frame_counts = batch["lengths"] // 320

    def without_padding(hidden_state):
        # Move to the CPU, split along the batch axis and drop the padded frames.
        per_example = hidden_state.cpu().unbind(0)
        return [per_example[k][:frame_counts[k], :] for k in range(len(frame_counts))]

    # Key 18 is taken from hidden_states[17], key 15 from hidden_states[14].
    # NOTE(review): confirm which layer these indices denote, given that the first
    # entry of hidden_states may be the embedding output rather than a layer output.
    layer_18 = without_padding(outputs.hidden_states[17])
    layer_15 = without_padding(outputs.hidden_states[14])

    return batch["ids"], {15: layer_15, 18: layer_18}

@torch.autocast(device_type="cuda")
@torch.inference_mode()
def extract_features_dataset(dataloader, model):
    """Run `extract_features_batch` over every batch and concatenate the results.

    Returns:
        Tuple `(ids, features)`: a flat list of ids and a `defaultdict(list)`
        mapping each layer key to a flat list of per-example tensors.
    """
    all_ids = []
    features_by_layer = defaultdict(list)

    progress = tqdm(dataloader, desc="Extracting features")
    for batch in progress:
        batch_ids, batch_features = extract_features_batch(batch, model)
        all_ids += batch_ids
        for layer in batch_features:
            features_by_layer[layer] += batch_features[layer]

    return all_ids, features_by_layer

In [None]:
# Directory for the cached features; it is created if it does not exist yet.
LIBRISPEECH_TEST_CLEAN_WITH_W2V2_FEATURES_DIR = os.path.join(TEMP_SAVE_DIR, "librispeech_test_clean_with_w2v2_features")
os.makedirs(LIBRISPEECH_TEST_CLEAN_WITH_W2V2_FEATURES_DIR, exist_ok=True)

try:
    # Cache hit: load the previously extracted features onto the CPU.
    features = torch.load(os.path.join(LIBRISPEECH_TEST_CLEAN_WITH_W2V2_FEATURES_DIR, "features.pt"), map_location="cpu")
except (AssertionError, FileNotFoundError):
    # Cache miss: extract the features with the pretrained model and save them.
    # Only `features` is written to disk; `ids` stays in memory.
    ids, features = extract_features_dataset(dataloader, wav2vec_large_pretrained)
    torch.save(features, os.path.join(LIBRISPEECH_TEST_CLEAN_WITH_W2V2_FEATURES_DIR, "features.pt"))