In [1]:
from transformers import Wav2Vec2Processor, Wav2Vec2ConformerForCTC
from datasets import load_dataset
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load model and processor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-conformer-rope-large-960h-ft")
model = Wav2Vec2ConformerForCTC.from_pretrained("facebook/wav2vec2-conformer-rope-large-960h-ft")

model.to(device)
    
# load dummy dataset and read soundfiles
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

raw_audio = ds[0]["audio"]["array"]
raw_audio = torch.tensor(raw_audio)

# tokenize
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt", padding="longest").input_values

# retrieve logits
logits = model(input_values.to(device)).logits

# take argmax and decode
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)

  from .autonotebook import tqdm as notebook_tqdm
It is strongly recommended to pass the `sampling_rate` argument to `Wav2Vec2FeatureExtractor()`. Failing to do so can result in silent errors that might be hard to debug.


In [2]:
print(len(ds[0]["audio"]["array"]))
print(raw_audio.dtype)
print(input_values.shape)
print(transcription) 

93680
torch.float64
torch.Size([1, 93680])
['MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL']


In [9]:
import shap
import numpy.ma as ma
import numpy as np

class ModelWrapper(torch.nn.Module):
    def __init__(self, model, processor):
        super(ModelWrapper, self).__init__()
        self.model = model
        self.processor = processor

    def forward(self, x):
        x = self.processor(x, return_tensors="pt", padding="longest").input_values.to(device)
        logits = self.model(x).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        return predicted_ids.cpu().numpy()

def masker(x, mask):
    x = ma.array(x, mask = mask*-1)
    x = x.filled(0.0)
    return x

# explain the model on two sample inputs
wrapped_model = ModelWrapper(model, processor)
explainer = shap.GradientExplainer(model=wrapped_model, data=[raw_audio])
shap_values = explainer.shap_values([raw_audio])


It is strongly recommended to pass the `sampling_rate` argument to `Wav2Vec2FeatureExtractor()`. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the `sampling_rate` argument to `Wav2Vec2FeatureExtractor()`. Failing to do so can result in silent errors that might be hard to debug.


RuntimeError: Calculated padded input size per channel: (1). Kernel size: (3). Kernel size can't be greater than actual input size