wav2vec 2.0
- `wav2vec2-large` pretrained: https://huggingface.co/facebook/wav2vec2-large-lv60
- `wav2vec2-large` fine-tuned: https://huggingface.co/facebook/wav2vec2-large-960h-lv60

Whisper
- `whisper-large-v2`: https://huggingface.co/openai/whisper-large-v2

In [None]:
# Hugging Face Hub identifiers of the checkpoints listed above.
WAV2VEC_LARGE_PRETRAINED = "facebook/wav2vec2-large-lv60"
WAV2VEC_LARGE_FINETUNED = "facebook/wav2vec2-large-960h-lv60"
WHISPER_LARGE_V2 = "openai/whisper-large-v2"

In [None]:
# Local filesystem locations used by this notebook.
# HUGGINFACE_HOME is exported as HF_HOME further down; TEMP_SAVE_DIR is the base
# directory under which transcribed datasets are saved and reloaded.
# NOTE(review): "HUGGINFACE" looks like a typo for "HUGGINGFACE"; a rename must
# also update the place where the constant is used.
HUGGINFACE_HOME = "/m2/research/huggingface"
TEMP_SAVE_DIR = "/m2/research/jdh/thesis/"

In [None]:
import os

# Point the Hugging Face libraries at the configured home directory.
# NOTE(review): presumably this has to run before `datasets`/`transformers`
# are imported for it to take effect — confirm.
os.environ["HF_HOME"] = HUGGINFACE_HOME

In [None]:
import torch
import functools

from datasets import load_dataset, load_from_disk
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Config
from torch.utils.data import DataLoader
from jiwer import wer
from tqdm import tqdm

In [None]:
# Sanity check: is CUDA usable, and how many GPUs are visible?
torch.cuda.is_available(), torch.cuda.device_count()

In [None]:
# Work on the first GPU and make it the default device for tensor creation.
device = torch.device('cuda:0')

torch.set_default_device(device)

# Uncertainty in transcripts

In [None]:
# Load the "test" split of the "clean" configuration of LibriSpeech ASR.
librispeech_test_clean = load_dataset("librispeech_asr", "clean", split="test")

In [None]:
def compute_length(example):
    """Store the length of the example's audio array under ``"length"``.

    The example is modified in place and also returned, so the function can
    be passed to ``Dataset.map``.
    """
    waveform = example["audio"]["array"]
    example["length"] = len(waveform)
    return example

In [None]:
# Add a "length" column (length of each example's audio array) to the dataset.
librispeech_test_clean = librispeech_test_clean.map(compute_length)

In [None]:
# Load the fine-tuned wav2vec 2.0 CTC model onto `device`, together with the
# feature extractor and CTC tokenizer of the same checkpoint.
w2v2_model = Wav2Vec2ForCTC.from_pretrained(WAV2VEC_LARGE_FINETUNED).to(device)
w2v2_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(WAV2VEC_LARGE_FINETUNED)
w2v2_tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(WAV2VEC_LARGE_FINETUNED)
# processor = Wav2Vec2Processor.from_pretrained(WAV2VEC_LARGE_FINETUNED)

In [None]:
# Peek at the first two examples.
librispeech_test_clean[:2]

In [None]:
def collate_fn(batch, tokenizer, feature_extractor):
    """Collate a list of dataset examples into one padded batch.

    Args:
        batch: Sequence of examples; each must provide "file", "audio" (with
            an "array" entry), "text", "speaker_id", "chapter_id" and "id".
        tokenizer: Callable applied to the list of texts. Its ``input_ids``
            (requested as NumPy, padded to the longest) become the labels.
        feature_extractor: Callable applied to the list of audio arrays with
            a sampling rate of 16 kHz. Its ``input_values`` and
            ``attention_mask`` (requested as NumPy, padded to the longest)
            become the model inputs.

    Returns:
        A dict with the tensors "input_values", "attention_mask" and
        "labels", plus the per-example lists "speaker_ids", "chapter_ids",
        "ids", "files" and "texts".
    """
    def column(key):
        # Gather one field across all examples of the batch.
        return [example[key] for example in batch]

    transcripts = column("text")
    waveforms = [audio["array"] for audio in column("audio")]

    token_ids = tokenizer(
        transcripts,
        return_tensors="np",
        padding="longest",
        return_attention_mask=False,
    ).input_ids
    extracted = feature_extractor(
        waveforms,
        sampling_rate=16_000,
        padding="longest",
        return_tensors="np",
        return_attention_mask=True,
    )

    return {
        "input_values": torch.from_numpy(extracted.input_values),
        "attention_mask": torch.from_numpy(extracted.attention_mask),
        "labels": torch.from_numpy(token_ids),
        "speaker_ids": column("speaker_id"),
        "chapter_ids": column("chapter_id"),
        "ids": column("id"),
        "files": column("file"),
        "texts": transcripts,
    }

In [None]:
# Bind the wav2vec 2.0 tokenizer and feature extractor into the collate function,
# build a DataLoader over the test set, and fetch one batch as a smoke test.
collate_fn_w2v2 = functools.partial(collate_fn, tokenizer=w2v2_tokenizer, feature_extractor=w2v2_feature_extractor)
dataloader = DataLoader(librispeech_test_clean, batch_size=8, collate_fn=collate_fn_w2v2, num_workers=4)
iterator = iter(dataloader)
next(iterator)

In [None]:
@torch.autocast(device_type="cuda")
@torch.inference_mode()
def transcribe_batch(batch, model, tokenizer):
    """Transcribe one collated batch.

    The batch's "input_values" and "attention_mask" are moved to the global
    ``device`` and passed through ``model``. The argmax over the last logits
    dimension is then decoded with ``tokenizer.batch_decode``.

    Runs under CUDA autocast and inference mode (see the decorators).

    Returns:
        Whatever ``tokenizer.batch_decode`` returns for the predicted ids.
    """
    outputs = model(
        batch["input_values"].to(device),
        attention_mask=batch["attention_mask"].to(device),
    )
    best_token_ids = outputs.logits.argmax(dim=-1)
    return tokenizer.batch_decode(best_token_ids)

In [None]:
def transcribe_dataset(dataloader, model, tokenizer):
    """Transcribe every batch of ``dataloader`` and return one flat list.

    Batches are processed in iteration order with ``transcribe_batch``, and a
    tqdm progress bar is shown while iterating.
    """
    collected = []
    for minibatch in tqdm(dataloader):
        collected += transcribe_batch(minibatch, model, tokenizer)
    return collected

In [None]:
# Load the test set with cached wav2vec 2.0 transcripts if a usable copy exists;
# otherwise transcribe it now, add a "transcript" column, and cache the result.
LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR = os.path.join(TEMP_SAVE_DIR, "librispeech_test_clean_with_w2v2_transcripts")
try:
    librispeech_test_clean = load_from_disk(LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR)
    # A cached copy without the column is treated like a missing cache.
    assert "transcript" in librispeech_test_clean.column_names
except (AssertionError, FileNotFoundError):
    transcripts = transcribe_dataset(dataloader, w2v2_model, w2v2_tokenizer)
    librispeech_test_clean = librispeech_test_clean.add_column("transcript", transcripts)
    librispeech_test_clean.save_to_disk(LIBRISPEECH_TEST_CLEAN_WITH_W2V2_TRANSCRIPTS_DIR)

In [None]:
def compute_mc_transcripts(dataloader, model, tokenizer, num_samples):
    """Transcribe the dataset ``num_samples`` times with the model in train mode.

    Each pass is cached under ``TEMP_SAVE_DIR`` in a directory named
    ``librispeech_test_clean_with_w2v2_transcripts_mc_{i:03d}``. A pass whose
    cache loads successfully is reused instead of being recomputed.

    Args:
        dataloader: Batches to transcribe, forwarded to ``transcribe_dataset``.
        model: Model to sample from. It is put in train mode for the duration
            of the call and restored to its previous mode afterwards.
        tokenizer: Tokenizer forwarded to ``transcribe_dataset``.
        num_samples: Number of passes to run (or load from cache).

    Returns:
        A list with one entry per pass, each holding that pass's transcripts.
    """
    # Train mode is what makes repeated passes differ.
    # NOTE(review): besides dropout, train mode may enable other training-only
    # behaviour of the model (e.g. input masking) — confirm this is intended.
    was_training = model.training
    model.train()
    mc_transcripts = []
    try:
        for i in tqdm(range(num_samples)):
            out_dir = os.path.join(TEMP_SAVE_DIR, f"librispeech_test_clean_with_w2v2_transcripts_mc_{i:03d}")
            try:
                transcripts = load_from_disk(out_dir)["mc_transcript"]
            except Exception:
                # Any unusable cache falls back to recomputation, but unlike a
                # bare `except:` this lets KeyboardInterrupt/SystemExit through.
                transcripts = transcribe_dataset(dataloader, model, tokenizer)
                mc_dataset = librispeech_test_clean.add_column("mc_transcript", transcripts)
                mc_dataset.save_to_disk(out_dir)

            mc_transcripts.append(transcripts)
    finally:
        # Do not leave the caller's model with dropout enabled.
        model.train(was_training)

    # Return every pass, not just the last one.
    return mc_transcripts

In [None]:
# Run (or load from cache) two Monte Carlo Dropout transcription passes over the test set.
mc_transcripts = compute_mc_transcripts(dataloader, w2v2_model, w2v2_tokenizer, 2)

In [None]:
# Display the "transcript" column.
librispeech_test_clean["transcript"]

In [None]:
# Unconditionally abort execution at this point.
# NOTE(review): presumably a guard so "Run All" does not reach the cells below — confirm.
raise Exception("STOP HERE")

In [None]:
def map_to_pred(batch):
    """Add wav2vec 2.0 transcriptions to a batch, for ``Dataset.map(batched=True)``.

    Uses the module-level ``w2v2_feature_extractor``, ``w2v2_model`` and
    ``w2v2_tokenizer``. The previous ``processor`` and ``model`` names are not
    defined anywhere in this notebook and raised NameError.

    Args:
        batch: Dict of columns. ``batch["audio"]`` must be a list of dicts
            with "array" and "sampling_rate" entries.

    Returns:
        The same ``batch`` with an added "transcription" list.
    """
    audios = [a["array"] for a in batch["audio"]]
    # The sampling rate of the first example is used for the whole batch.
    sampling_rate = batch["audio"][0]["sampling_rate"]
    inputs = w2v2_feature_extractor(
        audios,
        sampling_rate=sampling_rate,
        return_tensors="pt",
        padding="longest",
        return_attention_mask=True,
    )
    input_values = inputs.input_values.to(device)
    attention_mask = inputs.attention_mask.to(device)

    with torch.inference_mode():
        logits = w2v2_model(input_values, attention_mask=attention_mask).logits

    # Pick the highest-scoring token at every position, then decode.
    predicted_ids = torch.argmax(logits, dim=-1)
    batch["transcription"] = w2v2_tokenizer.batch_decode(predicted_ids)
    return batch


# Reuse cached transcriptions when available; otherwise transcribe and cache.
# `result` is bound on both paths so the WER line below always works, and
# `librispeech_test_clean` itself is left untouched.
transcription_cache_dir = TEMP_SAVE_DIR + "librispeech_test_clean"
try:
    result = load_from_disk(transcription_cache_dir)
except FileNotFoundError:
    result = None

# Column membership has to be checked on `column_names`; `in` on the dataset
# itself compares against rows and is never true for a column name.
if result is None or "transcription" not in result.column_names:
    result = librispeech_test_clean.map(map_to_pred, batched=True, batch_size=8)
    # Save the mapped dataset (the one that actually has the new column).
    result.save_to_disk(transcription_cache_dir)

print("WER:", wer(result["text"], result["transcription"]))

In [None]:
# DataLoader that yields raw lists of examples (identity collate function),
# in dataset order, keeping the final partial batch.
dataloader = torch.utils.data.DataLoader(librispeech_test_clean, batch_size=8, shuffle=False, collate_fn=lambda x: x, drop_last=False, num_workers=8)

In [None]:
# Create an iterator over the DataLoader.
it = iter(dataloader)

In [None]:
# Fetch and display the next batch.
next(it)

# Uncertainty in representations

In [None]:
# Reload the "test" split of LibriSpeech "clean" for the representation experiments.
librispeech_test_clean = load_dataset("librispeech_asr", "clean", split="test")

In [None]:
# Preview the dataset formatted as torch tensors on `device`.
# `with_format` returns a new dataset object and the result is only displayed
# here, so `librispeech_test_clean` itself keeps its current format.
# (The dangling `librispeech_test_clean.` line was a syntax error and is dropped.)
librispeech_test_clean.with_format("torch", device=device)

In [None]:
# Load the pretrained wav2vec 2.0 checkpoint: its config, its feature
# extractor, and the base model (moved to `device`).
# processor = Wav2Vec2Processor.from_pretrained(WAV2VEC_LARGE_PRETRAINED)
wav2vec_config = Wav2Vec2Config.from_pretrained(WAV2VEC_LARGE_PRETRAINED)
wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(WAV2VEC_LARGE_PRETRAINED)
wav2vec_large_pretrained = Wav2Vec2Model.from_pretrained(WAV2VEC_LARGE_PRETRAINED).to(device)

In [None]:
# Display the model configuration.
wav2vec_config

In [None]:
def list_of_dicts_to_dict_of_lists(list_of_dicts):
    """Transpose a list of dicts into a dict of lists.

    The keys (and their order) are taken from the first dict. Every other
    dict must contain those keys, otherwise ``KeyError`` is raised.
    An empty input yields an empty dict instead of raising ``IndexError``.
    """
    if not list_of_dicts:
        return {}
    keys = list_of_dicts[0]
    return {k: [d[k] for d in list_of_dicts] for k in keys}

In [None]:
# Take the first two examples as a small batch, keyed by column name.
batch = librispeech_test_clean[0:2]

In [None]:
# Display the "file" entries of the batch.
batch["file"]

In [None]:
# A sliced batch holds a list of per-example audio entries, so the raw arrays
# are gathered explicitly (as map_to_pred does); indexing the list with
# "array" raised TypeError. The keyword is `sampling_rate`, not `sample_rate`.
inputs = wav2vec_feature_extractor(
    [audio["array"] for audio in batch["audio"]],
    sampling_rate=16000,
    return_tensors="pt",
    padding="longest",
)

In [None]:
# Run the padded batch through the pretrained model on the GPU, without
# gradient tracking, and request the hidden states as well.
input_values, attention_mask = (
    inputs.input_values.to("cuda"),
    inputs.attention_mask.to("cuda"),
)

with torch.no_grad():
    features = wav2vec_large_pretrained(
        input_values,
        attention_mask=attention_mask,
        output_hidden_states=True,
    )


In [None]:
# Display the model outputs.
features

In [None]:
def map_to_wav2vec2_features(batch):
    """Attach wav2vec 2.0 hidden states to a batch, for ``Dataset.map(batched=True)``.

    Args:
        batch: Dict of columns. ``batch["audio"]`` must be a list of dicts
            with "array" and "sampling_rate" entries.

    Returns:
        The same ``batch`` with entry 15 of the model's ``hidden_states``
        stored under "features-15".
    """
    # With batched=True the audio column is a list, so indexing it with
    # "array" raised TypeError; gather the arrays per example instead.
    audios = [a["array"] for a in batch["audio"]]
    # The sampling rate of the first example is used for the whole batch.
    sampling_rate = batch["audio"][0]["sampling_rate"]
    inputs = wav2vec_feature_extractor(
        audios,
        sampling_rate=sampling_rate,
        return_tensors="pt",
        padding="longest",
    )
    input_values = inputs.input_values.to("cuda")
    attention_mask = inputs.attention_mask.to("cuda")

    with torch.no_grad():
        features = wav2vec_large_pretrained(input_values, attention_mask=attention_mask, output_hidden_states=True)

    # NOTE(review): the batch was padded to its longest audio, so shorter
    # examples presumably carry padding frames in these features — confirm
    # whether they should be trimmed before storing.
    batch["features-15"] = features.hidden_states[15]
    return batch



In [None]:
# Compute the features over the whole split in batches of 8; the mapped
# dataset is only displayed here, not assigned to a variable.
librispeech_test_clean.map(map_to_wav2vec2_features, batched=True, batch_size=8)