wav2vec 2.0
- `wav2vec2-large` pretrained: https://huggingface.co/facebook/wav2vec2-large-lv60
- `wav2vec2-large` fine-tuned: https://huggingface.co/facebook/wav2vec2-large-960h-lv60

Whisper
- `whisper-large-v2`: https://huggingface.co/openai/whisper-large-v2

To precompute the transcripts and other expensive intermediate results, convert the notebook to a script and run it with IPython from the terminal:

```bash
jupyter nbconvert --to python "ssl_uncertainty.ipynb" && ipython "ssl_uncertainty.py"

```

In [None]:
# Hugging Face Hub identifiers of the checkpoints referenced in this notebook
# (see the links at the top of the file).
WAV2VEC_LARGE_PRETRAINED = "facebook/wav2vec2-large-lv60"
WAV2VEC_LARGE_FINETUNED = "facebook/wav2vec2-large-960h-lv60"
WHISPER_LARGE_V2 = "openai/whisper-large-v2"

In [None]:
# Machine-specific paths: the value exported as HF_HOME below, and the root
# directory under which cached transcripts/logits are written.
HUGGINFACE_HOME = "/m2/research/huggingface"
TEMP_SAVE_DIR = "/m2/research/jdh/thesis/"

In [None]:
import os

# Exported before the Hugging Face libraries are imported in the next cell
# (presumably so that they pick up this cache location — confirm).
os.environ["HF_HOME"] = HUGGINFACE_HOME
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import functools

from typing import List, Dict
from copy import deepcopy
from collections import defaultdict

import torch
import numba
import numpy as np
import rich
import matplotlib.pyplot as plt
import scipy
import pandas as pd

from datasets import load_dataset, load_from_disk
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Config
from torch.utils.data import DataLoader
from jiwer import wer, cer
from tqdm import tqdm
from IPython.display import HTML, display

In [None]:
# Report CUDA availability and the number of visible GPUs.
print(torch.cuda.is_available(), torch.cuda.device_count())

# The first GPU is hard-coded; there is no CPU fallback in this notebook.
device = torch.device('cuda:0')

# New tensors created without an explicit `device=` are placed on `device`.
torch.set_default_device(device)

# Uncertainty in transcripts

In [None]:
# Load the "test" split of the "clean" configuration of librispeech_asr.
librispeech_test_clean = load_dataset("librispeech_asr", "clean", split="test")

In [None]:
# Fine-tuned CTC model (moved to `device`) plus its matching feature extractor
# and tokenizer, all loaded from the same checkpoint.
w2v2_model = Wav2Vec2ForCTC.from_pretrained(WAV2VEC_LARGE_FINETUNED).to(device)
w2v2_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(WAV2VEC_LARGE_FINETUNED)
w2v2_tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(WAV2VEC_LARGE_FINETUNED)
# processor = Wav2Vec2Processor.from_pretrained(WAV2VEC_LARGE_FINETUNED)

In [None]:
# Inspect the first two examples of the dataset.
librispeech_test_clean[:2]

In [None]:
def collate_fn(batch, tokenizer, feature_extractor):
    """Turn a list of dataset examples into one padded model batch.

    Args:
        batch: list of examples, each providing the keys "file", "audio"
            (with an "array" entry), "text", "speaker_id", "chapter_id",
            "id" and "length".
        tokenizer: callable used to encode the texts; its `.input_ids`
            (NumPy, padded to the longest text) become the labels.
        feature_extractor: callable used to encode the waveforms; its
            `.input_values` and `.attention_mask` (NumPy, padded to the
            longest waveform) become the model inputs.

    Returns:
        dict with tensor entries "input_values", "attention_mask", "labels"
        and "lengths" (explicitly on the CPU), and plain-list entries
        "speaker_ids", "chapter_ids", "ids", "files" and "texts".
    """

    def column(key):
        # Gather one field across all examples of the batch.
        return [example[key] for example in batch]

    files = column("file")
    waveforms = [audio["array"] for audio in column("audio")]
    texts = column("text")
    speaker_ids = column("speaker_id")
    chapter_ids = column("chapter_id")
    utterance_ids = column("id")
    sample_counts = column("length")

    label_array = tokenizer(
        texts,
        return_tensors="np",
        padding="longest",
        return_attention_mask=False,
    ).input_ids
    encoded_audio = feature_extractor(
        waveforms,
        sampling_rate=16_000,
        padding="longest",
        return_tensors="np",
        return_attention_mask=True,
    )
    value_array = encoded_audio.input_values
    mask_array = encoded_audio.attention_mask

    labels = torch.from_numpy(label_array)
    input_values = torch.from_numpy(value_array)
    attention_mask = torch.from_numpy(mask_array)

    return {
        "input_values": input_values,
        "attention_mask": attention_mask,
        "labels": labels,
        "speaker_ids": speaker_ids,
        "chapter_ids": chapter_ids,
        "ids": utterance_ids,
        "files": files,
        "texts": texts,
        # Pinned to the CPU explicitly, independent of the default device.
        "lengths": torch.as_tensor(sample_counts, device="cpu"),
    }

In [None]:
def compute_length(example):
    """Store the number of audio samples of `example` under "length".

    The example is modified in place and returned.
    """
    waveform = example["audio"]["array"]
    n_samples = len(waveform)
    example["length"] = n_samples
    return example

# Apply compute_length to every example so that each one carries a "length" entry.
librispeech_test_clean = librispeech_test_clean.map(compute_length)

In [None]:
# Bind the wav2vec 2.0 tokenizer and feature extractor into the collate function.
collate_fn_w2v2 = functools.partial(collate_fn, tokenizer=w2v2_tokenizer, feature_extractor=w2v2_feature_extractor)
# One utterance per batch, collated by four worker processes.
dataloader = DataLoader(librispeech_test_clean, batch_size=1, collate_fn=collate_fn_w2v2, num_workers=4)
# Pull a single batch to inspect the collated output.
iterator = iter(dataloader)
next(iterator)

In [None]:
@torch.autocast(device_type="cuda")
@torch.inference_mode()
def transcribe_batch(batch, model, tokenizer):
    """Transcribe one collated batch with greedy (argmax) decoding.

    Args:
        batch: dict providing "input_values", "attention_mask", "lengths"
            (number of audio samples per example) and "ids".
        model: model whose output exposes `.logits` and whose `config`
            provides `conv_kernel` and `conv_stride` (as the wav2vec 2.0
            configuration does).
        tokenizer: tokenizer providing `batch_decode`.

    Returns:
        Tuple `(ids, transcripts, logits)` where `logits` is a list with one
        tensor per example, trimmed to that example's unpadded frames.
    """
    input_values = batch["input_values"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    logits = model(input_values, attention_mask=attention_mask).logits

    pred_ids = torch.argmax(logits, dim=-1)
    transcripts = tokenizer.batch_decode(pred_ids)

    # Remove padding from the logits. The number of frames produced by the
    # convolutional feature encoder is floor((L - kernel) / stride) + 1 per
    # layer; `L // 320` overestimates it (e.g. 50 instead of 49 frames for
    # 16000 samples), which would keep a padded frame in padded batches.
    frame_counts = batch["lengths"]
    for kernel, stride in zip(model.config.conv_kernel, model.config.conv_stride):
        frame_counts = torch.div(frame_counts - kernel, stride, rounding_mode="floor") + 1
    frame_counts = frame_counts.tolist()

    logits = [
        sample_logits[:n_frames, :]
        for sample_logits, n_frames in zip(logits.unbind(0), frame_counts)
    ]
    return batch["ids"], transcripts, logits

In [None]:
@torch.inference_mode()
def transcribe_dataset(dataloader, model, tokenizer):
    """Run `transcribe_batch` over every batch of `dataloader`.

    Returns:
        Tuple `(ids, transcripts, logits)` of flat lists, concatenated over
        the batches in the order the dataloader yields them.
    """
    all_ids, all_transcripts, all_logits = [], [], []

    progress = tqdm(dataloader, desc="Transcribing")
    for batch in progress:
        batch_ids, batch_transcripts, batch_logits = transcribe_batch(batch, model, tokenizer)
        all_transcripts += batch_transcripts
        all_logits += batch_logits
        all_ids += batch_ids

    return all_ids, all_transcripts, all_logits

In [None]:
# Transcribe the dataset and save the results (load if exists already)

LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR = os.path.join(TEMP_SAVE_DIR, "librispeech_test_clean_with_w2v2_transcripts")
logits_path = os.path.join(LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR, "logits.pt")

# Only adopt the cached dataset once BOTH the dataset (with its "transcript"
# column) and the logits file have been loaded. Rebinding the global dataset
# before the logits are loaded would make the recompute path below call
# `add_column("transcript", ...)` on a dataset that already has that column.
cached_dataset = None
try:
    candidate = load_from_disk(LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR)
    if "transcript" in candidate.column_names:
        logits = torch.load(logits_path)
        cached_dataset = candidate
except FileNotFoundError:
    cached_dataset = None

if cached_dataset is not None:
    librispeech_test_clean = cached_dataset
else:
    ids, transcripts, logits = transcribe_dataset(dataloader, w2v2_model, w2v2_tokenizer)
    librispeech_test_clean = librispeech_test_clean.add_column("transcript", transcripts)
    librispeech_test_clean.save_to_disk(LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR)
    torch.save(logits, logits_path)

In [None]:
def compute_mc_transcripts(dataloader, model, tokenizer, num_samples):
    """Compute Monte Carlo transcripts for a dataset `num_samples` times.

    Each pass is cached as a dataset with an "mc_transcript" column plus an
    "mc_logits.pt" file in its own directory under TEMP_SAVE_DIR; a pass is
    recomputed when that cache is missing.

    NOTE: a later cell defines another function with this same name, which
    replaces this one when the notebook is run top to bottom.

    Args:
        dataloader: dataloader passed on to `transcribe_dataset`.
        model: model to sample from; put in train mode for the duration of
            the call and restored to its previous mode afterwards.
        tokenizer: tokenizer passed on to `transcribe_dataset`.
        num_samples: number of passes over the dataset.

    Returns:
        List of `num_samples` transcript collections, one per pass.
    """
    was_training = model.training
    model.train()  # set to train mode to enable dropout
    mc_transcripts = []

    try:
        bar = tqdm(range(num_samples), desc="Getting MC Transcripts")
        for i in bar:
            out_dir = os.path.join(TEMP_SAVE_DIR, f"librispeech_test_clean_with_w2v2_transcripts_mc_{i:03d}")
            logits_path = os.path.join(out_dir, "mc_logits.pt")
            bar.set_description(f"Getting MC Transcripts {i:03d}")
            try:
                mc_dataset = load_from_disk(out_dir)
                transcripts = mc_dataset["mc_transcript"]
                # The logits are not returned, so only check that the cache is
                # complete instead of loading them.
                if not os.path.isfile(logits_path):
                    raise FileNotFoundError(logits_path)
            except (FileNotFoundError, KeyError):
                # transcribe_dataset returns (ids, transcripts, logits).
                _, transcripts, logits = transcribe_dataset(dataloader, model, tokenizer)
                mc_dataset = librispeech_test_clean.add_column("mc_transcript", transcripts)
                mc_dataset.save_to_disk(out_dir)
                torch.save(logits, logits_path)

            mc_transcripts.append(transcripts)
    finally:
        # Do not leave the caller's model with dropout switched on.
        model.train(was_training)

    return mc_transcripts

In [None]:
def compute_mc_transcripts(dataloader, model, tokenizer, num_samples):
    """Compute Monte Carlo transcripts for a dataset `num_samples` times.

    For every pass `i`, the ids, transcripts and logits are loaded from
    `mc_ids_{i:03d}.pt`, `mc_transcript_{i:03d}.pt` and `mc_logits_{i:03d}.pt`
    inside LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR; if any of the
    three files is missing the pass is recomputed and all three are saved.

    Args:
        dataloader: dataloader passed on to `transcribe_dataset`.
        model: model to sample from; put in train mode for the duration of
            the call and restored to its previous mode afterwards.
        tokenizer: tokenizer passed on to `transcribe_dataset`.
        num_samples: number of passes over the dataset.

    Returns:
        Tuple `(mc_ids, mc_transcripts, mc_logits)`; each is a list with one
        entry per pass.
    """
    was_training = model.training
    # Train mode enables dropout.
    # NOTE(review): it may also enable other training-only behaviour of the
    # model — confirm that this is intended for the MC samples.
    model.train()

    mc_ids = []
    mc_transcripts = []
    mc_logits = []

    cache_dir = LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR
    try:
        # bar = tqdm(reversed(range(num_samples)), desc="Getting MC Transcripts")
        bar = tqdm(range(num_samples), desc="Getting MC Transcripts")
        for i in bar:
            bar.set_description(f"Getting MC Transcripts {i:03d}")

            ids_path = os.path.join(cache_dir, f"mc_ids_{i:03d}.pt")
            transcripts_path = os.path.join(cache_dir, f"mc_transcript_{i:03d}.pt")
            logits_path = os.path.join(cache_dir, f"mc_logits_{i:03d}.pt")

            try:
                ids = torch.load(ids_path)
                transcripts = torch.load(transcripts_path)
                logits = torch.load(logits_path)
            except FileNotFoundError as exc:
                print(exc)
                ids, transcripts, logits = transcribe_dataset(dataloader, model, tokenizer)
                torch.save(ids, ids_path)
                torch.save(transcripts, transcripts_path)
                torch.save(logits, logits_path)

            mc_ids.append(ids)
            mc_transcripts.append(transcripts)
            mc_logits.append(logits)
    finally:
        # Do not leave the caller's model with dropout switched on.
        model.train(was_training)

    return mc_ids, mc_transcripts, mc_logits

In [None]:
# Compute and add Monte Carlo Dropout simulations
# NUM_SAMPLES is the number of stochastic passes over the whole dataset; it is
# also used further down to normalise per-word counts.
NUM_SAMPLES = 256

# Each returned list has one entry per pass.
mc_ids, mc_transcripts, mc_logits = compute_mc_transcripts(dataloader, w2v2_model, w2v2_tokenizer, num_samples=NUM_SAMPLES)


In [None]:
# raise Exception("Stop here")

In [None]:
class AlignmentElement:
    """One column of a word alignment between a reference and a hypothesis.

    Attributes:
        type: label of the element; the aligner in this file uses "MATCHED",
            "MISSPELLED", "OMITTED" and "EXTRA".
        hyp_word: word of the hypothesis transcript ("" when omitted).
        ref_word: word of the reference transcript ("" for an extra word).
        cer: character error rate of `hyp_word` with respect to `ref_word`;
            fixed to 1.0 when either word is empty.
    """
    def __init__(self, type_, hyp_word, ref_word):
        self.type = type_
        self.hyp_word = hyp_word
        self.ref_word = ref_word
        if hyp_word and ref_word:
            # jiwer's signature is cer(reference, hypothesis): the reference
            # goes first so that the rate is normalised by the reference word,
            # matching how Alignment calls cer(reference, prediction).
            self.cer = cer(ref_word, hyp_word)
        else:
            self.cer = 1.0

    def __repr__(self):
        return f"AlignmentElement(type_='{self.type}', hyp_word='{self.hyp_word}', ref_word='{self.ref_word}')"


class Alignment:
    """Word-level alignment between a reference and a predicted transcript.

    Attributes:
        reference: the reference transcript, as given.
        prediction: the predicted transcript, as given.
        alignment: list of AlignmentElement, one per alignment column.
        wer: word error rate of `prediction` against `reference`.
        cer: character error rate of `prediction` against `reference`.

    A prediction that contains no words is represented as every reference
    word being "OMITTED", with `wer` and `cer` fixed to 1.0.
    """
    def __init__(self, reference, prediction):
        self.reference = reference
        self.prediction = prediction
        # `split()` yields no words for both "" and whitespace-only strings;
        # checking only `len(prediction) == 0` would send a whitespace-only
        # prediction to the aligner with an empty word list.
        if not prediction.split():
            self.alignment = [AlignmentElement("OMITTED", "", word) for word in reference.split()]
            self.wer = 1.0
            self.cer = 1.0
        else:
            self.alignment = align_transcripts(reference, prediction)
            self.wer = wer(reference, prediction)
            self.cer = cer(reference, prediction)

    def visualize(self, include_cer=False, other_numeric_scores: Dict[str, List[float]] = None):
        """Render the alignment as an HTML table via `visualize_alignment`."""
        visualize_alignment(self.alignment, include_cer=include_cer, other_numeric_scores=other_numeric_scores)

    def __repr__(self):
        return f"Alignment(\n\treference='{self.reference}',\n\tprediction='{self.prediction}'\n)"


def levenstein_distance_nojit(reference: List[str], prediction: List[str]) -> np.ndarray:
    """Compute the Levenshtein cost matrix between two tokenized strings.

    Pure-Python counterpart of `levenstein_distance`.

    Args:
        reference: tokens of the reference.
        prediction: tokens of the prediction.

    Returns:
        Float array of shape `(len(reference) + 1, len(prediction) + 1)`;
        entry `[i, j]` is the edit distance between the first `i` reference
        tokens and the first `j` prediction tokens, so the last entry is the
        distance between the full sequences.
    """
    n_ref, n_pred = len(reference), len(prediction)

    # Create a matrix to store alignment costs
    alignment_matrix = np.zeros((n_ref + 1, n_pred + 1))

    # First column: cost of deleting the first i reference tokens
    alignment_matrix[:, 0] = np.arange(n_ref + 1)

    # First row: cost of inserting the first j prediction tokens
    alignment_matrix[0, :] = np.arange(n_pred + 1)

    # Fill in the alignment matrix
    for i in range(1, n_ref + 1):
        for j in range(1, n_pred + 1):
            cost = 0 if reference[i - 1] == prediction[j - 1] else 1
            alignment_matrix[i, j] = min(
                alignment_matrix[i - 1, j] + 1,        # Deletion
                alignment_matrix[i, j - 1] + 1,        # Insertion
                alignment_matrix[i - 1, j - 1] + cost  # Substitution
            )

    return alignment_matrix



@numba.jit(nopython=True)
def levenstein_distance(reference: List[str], prediction: List[str]) -> np.ndarray:
    """Compute the Levenshtein cost matrix between two tokenized strings.

    Compiled with numba in nopython mode. Returns a float array of shape
    `(len(reference) + 1, len(prediction) + 1)` whose entry `[i, j]` is the
    edit distance between the first `i` reference tokens and the first `j`
    prediction tokens.
    """
    n_ref = len(reference)
    n_pred = len(prediction)

    # Cost table, including the row/column for the empty prefixes
    alignment_matrix = np.zeros((n_ref + 1, n_pred + 1))

    # Deleting i reference tokens costs i; inserting j prediction tokens costs j
    alignment_matrix[:, 0] = np.arange(n_ref + 1)
    alignment_matrix[0, :] = np.arange(n_pred + 1)

    for i in range(1, n_ref + 1):
        ref_token = reference[i - 1]
        for j in range(1, n_pred + 1):
            if ref_token == prediction[j - 1]:
                cost = 0
            else:
                cost = 1
            deletion = alignment_matrix[i - 1, j] + 1
            insertion = alignment_matrix[i, j - 1] + 1
            substitution = alignment_matrix[i - 1, j - 1] + cost
            alignment_matrix[i, j] = min(deletion, insertion, substitution)

    return alignment_matrix
    

def align_transcripts(reference: str, prediction: str) -> List["AlignmentElement"]:
    """Align the words of `prediction` to the words of `reference`.

    Both strings are split on whitespace, the Levenshtein cost matrix is
    computed with `levenstein_distance`, and the matrix is backtraced from its
    last cell. When several moves are possible, a deletion is preferred over
    an insertion, and an insertion over a match/substitution.

    Args:
        reference: reference transcript.
        prediction: transcript to compare against the reference.

    Returns:
        List of AlignmentElement in reading order, typed "MATCHED",
        "MISSPELLED" (substitution), "OMITTED" (reference word without a
        counterpart) or "EXTRA" (prediction word without a counterpart).
    """
    reference_words = reference.split()
    prediction_words = prediction.split()

    # With one side empty the alignment is trivial, so the jitted cost-matrix
    # function is never called with an empty list.
    if not reference_words:
        return [AlignmentElement("EXTRA", word, "") for word in prediction_words]
    if not prediction_words:
        return [AlignmentElement("OMITTED", "", word) for word in reference_words]

    alignment_matrix = levenstein_distance(reference_words, prediction_words)

    # Backtrace from the bottom-right cell. Elements are collected back to
    # front and reversed once at the end, instead of inserting at index 0.
    i, j = len(reference_words), len(prediction_words)
    reversed_alignment = []

    while i > 0 or j > 0:
        if i > 0 and alignment_matrix[i][j] == alignment_matrix[i - 1][j] + 1:  # Deletion
            reversed_alignment.append(AlignmentElement("OMITTED", "", reference_words[i - 1]))
            i -= 1
        elif j > 0 and alignment_matrix[i][j] == alignment_matrix[i][j - 1] + 1:  # Insertion
            reversed_alignment.append(AlignmentElement("EXTRA", prediction_words[j - 1], ""))
            j -= 1
        else:
            if reference_words[i - 1] != prediction_words[j - 1]:  # Substitution
                reversed_alignment.append(AlignmentElement("MISSPELLED", prediction_words[j - 1], reference_words[i - 1]))
            else:  # Match
                reversed_alignment.append(AlignmentElement("MATCHED", prediction_words[j - 1], reference_words[i - 1]))
            i -= 1
            j -= 1

    reversed_alignment.reverse()
    return reversed_alignment


def visualize_alignment(alignment: List[AlignmentElement], include_cer=False, other_numeric_scores: Dict[str, List[float]] = None):
    """Display an alignment as an HTML table with one column per element.

    `alignment` is a list of elements exposing `type`, `hyp_word`, `ref_word`
    and (when `include_cer` is set) `cer`. The reference words and the
    predicted words are shown in two rows on top of each other:

    - "MATCHED": both words shown without colour.
    - "MISSPELLED": the predicted word is shown in red.
    - "OMITTED": the reference word is shown in green, the prediction as "-".
    - "EXTRA": the predicted word is shown in orange, the reference as "-".

    With `include_cer`, a third row shows each element's CER. Every entry of
    `other_numeric_scores` adds one more row, labelled with its key and
    holding its values formatted with two decimals.
    """
    reference_cells = ["<tr><td>Reference</td>"]
    prediction_cells = ["<tr><td>Prediction</td>"]
    cer_cells = ["<tr><td>CER</td>"]

    for item in alignment:
        kind = item.type
        if kind == "MATCHED":
            reference_cells.append(f"<td>{item.ref_word}</td>")
            prediction_cells.append(f"<td>{item.hyp_word}</td>")
        elif kind == "MISSPELLED":
            reference_cells.append(f"<td>{item.ref_word}</td>")
            prediction_cells.append(f"<td style='color:red'>{item.hyp_word}</td>")
        elif kind == "OMITTED":
            reference_cells.append(f"<td style='color:green'>{item.ref_word}</td>")
            prediction_cells.append("<td>-</td>")
        elif kind == "EXTRA":
            reference_cells.append("<td>-</td>")
            prediction_cells.append(f"<td style='color:orange'>{item.hyp_word}</td>")

        if include_cer:
            cer_cells.append(f"<td>{item.cer:.2f}</td>")

    parts = ["<table>"]
    parts.append("".join(reference_cells) + "</tr>")
    parts.append("".join(prediction_cells) + "</tr>")

    if include_cer:
        parts.append("".join(cer_cells) + "</tr>")

    if other_numeric_scores is not None:
        for label, scores in other_numeric_scores.items():
            score_cells = "".join([f"<td>{score:.2f}</td>" for score in scores])
            parts.append(f"<tr><td>{label}</td>" + score_cells + "</tr>")

    parts.append("</table>")

    display(HTML("".join(parts)))

In [None]:
# Example usage: the prediction has one extra leading word, two misspelled
# words and is missing the final word of the reference.
reference_transcript = "the quick brown fox jumps far"
predicted_transcript = "and the quck broown fox jumps"

alignment = Alignment(reference_transcript, predicted_transcript)

alignment.visualize(include_cer=True)

In [None]:
# Inspect one utterance: its gold text, its standard transcript and the
# transcript from the first Monte Carlo pass.
example_id = 15

reference = librispeech_test_clean[example_id]["text"]
prediction = librispeech_test_clean[example_id]["transcript"]
mc_prediction = mc_transcripts[0][example_id]

print("Gold VS Prediction")
alignment = Alignment(reference, prediction)
alignment.visualize(include_cer=True)

# Here the standard transcript plays the role of the reference.
print("Prediction VS MC Prediction")
alignment = Alignment(prediction, mc_prediction)
alignment.visualize(include_cer=True)

In [None]:
# Compute the alignment to the prediction for all Monte Carlo transcripts
alignments_pred_to_gold = []
alignments_mc_to_pred = []

# mc_transcripts is indexed [pass][utterance]; transpose to [utterance][pass].
mc_transcripts_transposed = list(zip(*mc_transcripts))

# Read the two text columns once instead of indexing whole rows: indexing a
# row materializes every column of that row, including the audio, although
# only the texts are needed here.
references = librispeech_test_clean["text"]
predictions = librispeech_test_clean["transcript"]

for reference, prediction, predictions_mc in tqdm(
    zip(references, predictions, mc_transcripts_transposed),
    total=len(librispeech_test_clean),
    desc="Computing alignments",
):
    alignments_pred_to_gold.append(Alignment(reference, prediction))
    alignments_mc_to_pred.append([Alignment(prediction, prediction_mc) for prediction_mc in predictions_mc])

In [None]:
# Obtain the list of CERs over Monte Carlo transcripts for each word in the reference transcript
# ("reference" here is the standard prediction, since the MC transcripts were aligned to it).
# One dict per utterance, keyed by word:
#   all_errors            - CER of every alignment element, including EXTRA ones
#   misspellings          - CER of every element that is not EXTRA
#   fraction_misspellings - 1 / NUM_SAMPLES added per non-EXTRA element

all_errors = [defaultdict(list) for _ in range(len(librispeech_test_clean))]
misspellings = [defaultdict(list) for _ in range(len(librispeech_test_clean))]
fraction_misspellings = [defaultdict(int) for _ in range(len(librispeech_test_clean))]
for i in tqdm(range(len(librispeech_test_clean)), desc="Counting misspellings and omissions"):

    # idx2word = {i: e.ref_word for i, e in enumerate(alignments_mc_to_pred_i[0].alignment)}
    
    # Get the alignments
    alignments_mc_to_pred_i = alignments_mc_to_pred[i]  # Get the alignments for all Monte Carlo transcripts
    for alignment_mc_to_pred in alignments_mc_to_pred_i:

        # Keep track of how many times each reference word has been seen in a single alignment.
        # Some words might be used multiple times and we must discern between them.
        word_count = defaultdict(int)

        # Iterate over the alignment and save the CER for each word
        for element in alignment_mc_to_pred.alignment:
            word_count[element.ref_word] += 1
            
            # The n-th occurrence of a word gets n - 1 trailing underscores, so
            # repeated words map to distinct keys. EXTRA elements have an empty
            # ref_word, so their keys consist of underscores only.
            k = element.ref_word + "_" * (word_count[element.ref_word] - 1)
            all_errors[i][k].append(element.cer)

            if element.type != "EXTRA":
                misspellings[i][k].append(element.cer)
                # NOTE(review): this is incremented for every non-EXTRA element,
                # whether or not the word was actually misspelled — confirm that
                # this is the intended meaning of "fraction of misspellings".
                fraction_misspellings[i][k] += 1 / NUM_SAMPLES
                
# Sanity check: the number of non-OMITTED elements in the prediction-vs-gold
# alignments next to the number of per-word entries collected above.
print("Num CERs in gold: ", len([e.cer for a in alignments_pred_to_gold for e in a.alignment if e.type != "OMITTED"]))
print("Num CERs in pred: ", len([cer for error_cer in misspellings for cer in error_cer.values()]))

In [None]:
# Reduce each word's list of CERs (one per Monte Carlo alignment in which the
# word was not EXTRA) to its median and to its mean, per utterance.
median_misspellings = [{k: np.median(v) for k, v in misspelling.items()} for misspelling in misspellings]
mean_misspellings = [{k: np.mean(v) for k, v in misspelling.items()} for misspelling in misspellings]

In [None]:
example_id = 59
example_alignment = alignments_pred_to_gold[example_id]

# The median CERs hold one value per word of the standard prediction, whereas
# the table has one column per alignment element. OMITTED columns have no
# predicted word, so they get NaN; otherwise every later score would shift one
# column to the left.
median_cers = iter(median_misspellings[example_id].values())
aligned_median_cers = [
    float("nan") if element.type == "OMITTED" else next(median_cers, float("nan"))
    for element in example_alignment.alignment
]

example_alignment.visualize(include_cer=True, other_numeric_scores={"Median CER": aligned_median_cers})

In [None]:
# Flatten to one value per word: the CER of the standard prediction against the
# gold text (OMITTED elements are skipped), and the median CER over the Monte
# Carlo transcripts against the standard prediction.
cers_from_prediction = np.array([e.cer for a in alignments_pred_to_gold for e in a.alignment if e.type != "OMITTED"])
cers_from_montecarlo = np.array([cer for misspelling_cer in median_misspellings for cer in misspelling_cer.values()])
# cers_from_montecarlo = np.array([cer for misspelling_cer in mean_misspellings for cer in misspelling_cer.values()])
# The two shapes have to agree for the scatter plot and correlations below.
cers_from_prediction.shape, cers_from_montecarlo.shape

In [None]:
# Scatter plot of the two per-word CERs, annotated with their Pearson and
# Spearman correlation coefficients, saved as a PDF.
fig, ax = plt.subplots(figsize=(6.4, 4.8))
ax.scatter(cers_from_prediction, cers_from_montecarlo, alpha=0.1)

# Text positions are given in axes coordinates (transform=ax.transAxes).
ax.text(0.76, 0.95, f"Pearson: {scipy.stats.pearsonr(cers_from_prediction, cers_from_montecarlo)[0]:.3f}", transform=ax.transAxes)
ax.text(0.73, 0.90, f"Spearman: {scipy.stats.spearmanr(cers_from_prediction, cers_from_montecarlo)[0]:.3f}", transform=ax.transAxes)

ax.set_xlabel("CER of standard transcript cf. target")
ax.set_ylabel("Median CER of MC transcripts cf. standard transcript")
# Both axes are limited to [0, 1].
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
fig.savefig("mc_transcript_cer_scatter.pdf", bbox_inches="tight")

In [None]:
# Above but heatmap: contours of the log of a Gaussian kernel density estimate
# of the (prediction CER, Monte Carlo CER) pairs, with the data points on top.
# NOTE(review): LogNorm and sns are imported but not used in this cell.
from matplotlib.colors import LogNorm
import seaborn as sns

# One row per variable, as passed to gaussian_kde.
data = np.stack([cers_from_prediction, cers_from_montecarlo], axis=0)
kde = scipy.stats.gaussian_kde(data, bw_method=0.3)

xmin, xmax = 0, 1
ymin, ymax = 0, 1

# log_norm = LogNorm(vmin=data.min().min(), vmax=data.max().max())
# Evaluate the density on a 100 x 100 grid over the unit square.
xx, yy = np.meshgrid(np.linspace(xmin, xmax, 100), np.linspace(ymin, ymax, 100))
positions = np.vstack([xx.ravel(), yy.ravel()])
# values = np.vstack([x, y])

# Log-density on the grid, reshaped back to the grid's shape.
f = np.log(np.reshape(kde(positions).T, xx.shape))

fig = plt.figure()
ax = fig.gca()

ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)

# Filled contours plus thin black contour lines, 15 levels each.
ax.contourf(xx, yy, f, 15, cmap='Blues', label="Log-density")
cset = ax.contour(xx, yy, f, 15, colors='k', linewidths=0.5)

ax.scatter(cers_from_prediction, cers_from_montecarlo, alpha=0.1, marker=".", s=1, color="black", label="Data")

ax.clabel(cset, inline=1, fontsize=10)
ax.set_xlabel("CER from Prediction")
ax.set_ylabel("CER from Monte Carlo Dropout")
# ax.legend()

In [None]:
# Is there a correlation (per word) between the fraction of MC transcripts that are wrong compared to the standard transcript and whether the standard transcript was correct compared to the target?
# Is there a correlation (per word) between the fraction of MC transcripts that are wrong compared to the standard transcript and the CER of the standard transcript compared to the target?


# Uncertainty in representations

In [None]:
# Reload the split from scratch; this rebinds the name used in the first part
# of the notebook, so the "transcript" column added there is no longer present.
librispeech_test_clean = load_dataset("librispeech_asr", "clean", split="test")

In [None]:
# `with_format` returns a new dataset object and leaves the original as it is;
# the result is not assigned here, so `librispeech_test_clean` keeps its format.
# NOTE(review): assign the result if the torch format is meant to be used below.
librispeech_test_clean.with_format("torch", device=device)
# The original line ended in a dangling "." (a syntax error); show the dataset instead.
librispeech_test_clean

In [None]:
# processor = Wav2Vec2Processor.from_pretrained(WAV2VEC_LARGE_PRETRAINED)
# Configuration, feature extractor and base model (no CTC head) of the
# pretrained checkpoint; the model is moved to `device`.
wav2vec_config = Wav2Vec2Config.from_pretrained(WAV2VEC_LARGE_PRETRAINED)
wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(WAV2VEC_LARGE_PRETRAINED)
wav2vec_large_pretrained = Wav2Vec2Model.from_pretrained(WAV2VEC_LARGE_PRETRAINED).to(device)

In [None]:
# Show the configuration of the pretrained model.
wav2vec_config

In [None]:
def list_of_dicts_to_dict_of_lists(list_of_dicts):
    """Transpose a list of dicts into a dict of lists.

    The keys (and their order) are taken from the first dict; every dict in
    the list must provide all of them.
    """
    first_row = list_of_dicts[0]
    columns = {}
    for key in first_row:
        columns[key] = [row[key] for row in list_of_dicts]
    return columns

In [None]:
# Take the first two examples as a small batch to experiment with.
batch = librispeech_test_clean[0:2]

In [None]:
# Show the "file" entries of the two examples.
batch["file"]

In [None]:
# Slicing a dataset yields a dict of lists, so batch["audio"] holds one entry
# per example; collect the raw waveforms from those entries.
waveforms = [audio["array"] for audio in batch["audio"]]
# The keyword is `sampling_rate` (as in collate_fn above), not `sample_rate`.
inputs = wav2vec_feature_extractor(waveforms, sampling_rate=16000, return_tensors="pt", padding="longest")

In [None]:
# Move the padded inputs and their attention mask to the GPU.
input_values = inputs.input_values.to("cuda")
attention_mask = inputs.attention_mask.to("cuda")

# Forward pass without gradient tracking, requesting the hidden states in
# addition to the default outputs.
with torch.no_grad():
    features = wav2vec_large_pretrained(input_values, attention_mask=attention_mask, output_hidden_states=True)


In [None]:
# Show the model output of the forward pass above.
features

In [None]:
def map_to_wav2vec2_features(batch):
    """Add the pretrained model's hidden state 15 to a batch of examples.

    Intended for `Dataset.map(..., batched=True)`, where `batch` is a dict of
    lists and `batch["audio"]` holds one entry (with an "array") per example.

    Args:
        batch: dict of lists with an "audio" column.

    Returns:
        The same `batch` with an added "features-15" entry holding
        `hidden_states[15]` of the model output for the padded batch.
        NOTE(review): frames that stem from padding are not removed — confirm
        whether they should be trimmed per example before storing.
    """
    # batch["audio"] is a list, so the waveforms have to be collected per example.
    waveforms = [audio["array"] for audio in batch["audio"]]
    inputs = wav2vec_feature_extractor(waveforms, sampling_rate=16_000, return_tensors="pt", padding="longest")
    input_values = inputs.input_values.to("cuda")
    attention_mask = inputs.attention_mask.to("cuda")

    with torch.no_grad():
        features = wav2vec_large_pretrained(input_values, attention_mask=attention_mask, output_hidden_states=True)

    batch["features-15"] = features.hidden_states[15]

    return batch



In [None]:
# Compute the features in batches of eight examples.
# NOTE(review): the dataset returned by `map` is not assigned to a name here,
# so it is only shown as the cell output — confirm whether it should be kept.
librispeech_test_clean.map(map_to_wav2vec2_features, batched=True, batch_size=8)