In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import shap
import logging
import sys
import os
import soundfile as sf
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datasets import load_dataset
import librosa
import librosa.display
from scipy.stats import pearsonr
from typing import List, Dict, Tuple
import json
from pathlib import Path
from tqdm import tqdm
import warnings

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('evaluation.log')
    ]
)
logger = logging.getLogger(__name__)

model_name = "facebook/wav2vec2-base-960h"
            

  from .autonotebook import tqdm as notebook_tqdm
2025-09-15 22:37:53.089618: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-15 22:37:53.275837: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757968673.345518  119567 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757968673.365556  119567 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1757968673.512104  119567 computation_placer.cc:177] computation placer already r

In [2]:
def _create_model_wrapper(model):
    """Create a wrapper class for the model to get properly shaped output"""
    class ModelWrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model
        
        def forward(self, x):
            # Ensure input has correct shape for the model
            if len(x.shape) == 4:
                x = x.squeeze(1).squeeze(1)  # Remove extra dimensions
            elif len(x.shape) == 3:
                x = x.squeeze(1)  # Remove extra dimension
            
            # Add attention mask
            attention_mask = torch.ones_like(x)
            
            # Forward pass with attention mask
            logits = self.model(x, attention_mask=attention_mask).logits
            
            # Log the shape and statistics of logits
            logger.debug(f"Logits shape: {logits.shape}")
            logger.debug(f"Logits mean: {torch.mean(logits).item():.6f}")
            logger.debug(f"Logits std: {torch.std(logits).item():.6f}")
            
            # For SHAP, aggregate over vocab to get a scalar per time step
            return torch.max(logits,dim=-1).values  # [batch, seq_len]
    
    return ModelWrapper(model)

In [3]:
def _add_noise(audio: np.ndarray, snr_db: float) -> np.ndarray:
    """Add white noise to audio at specified SNR"""
    signal_power = np.mean(audio ** 2)
    noise_power = signal_power / (10 ** (snr_db / 10))
    noise = np.random.normal(0, np.sqrt(noise_power), len(audio))
    return audio + noise

In [4]:
def create_test_set(num_samples: int = 10) -> Dict:
    """Create a controlled test set with various conditions"""
    logger.info(f"Creating test set with {num_samples} samples")
    ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
    test_set = []
    dataset_index = 17

    for i in tqdm(range(min(num_samples, len(ds))), desc="Creating test samples"):
        sample = ds[i+dataset_index]
        audio = sample["audio"]["array"]
        while len(audio) < 100000:
            dataset_index += 1
            sample = ds[i+dataset_index]
            audio = sample["audio"]["array"]
        text = sample["text"]
        
        # Create clean sample
        test_set.append({
            "type": "clean",
            "audio": audio,
            "text": text,
            "snr": float('inf'),
            "noise": np.zeros_like(audio)
        })
        logger.info(f"Added clean sample {i+1}")
        
        # Create noisy samples with different SNRs [20, 10, 0, -5]
        for snr in tqdm([5,2,1], desc=f"Adding noise to sample {i+1}", leave=False):
            noisy_audio = _add_noise(audio, snr)
            test_set.append({
                "type": "noisy",
                "audio": noisy_audio,
                "text": text,
                "snr": snr,
                "noise": noisy_audio - audio
            })
            logger.info(f"Added noisy sample {i+1} with SNR {snr}dB")

    logger.info(f"Test set created with {len(test_set)} total samples")
    return test_set

In [5]:
def compute_shap_values(processor, device, wrapped_model, model, vocab, audio: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute SHAP values for an audio sample using GradientExplainer"""
    logger.info("Computing SHAP values")
    # Process audio
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    input_values = inputs.input_values.to(device)
    
    # Ensure input_values has the correct shape [batch_size, sequence_length]
    if len(input_values.shape) == 3:
        input_values = input_values.squeeze(1)
    
    # Create background samples with correct shape [batch_size, sequence_length]
    num_background = 5
    background = torch.zeros((num_background, input_values.shape[1]), device=device)
    background += torch.randn_like(background) * 0.01  # Add small random noise
    logger.info(f"Created background samples with shape {background.shape}")
    logger.info(f"Background mean: {torch.mean(background).item():.6f}")
    logger.info(f"Background std: {torch.std(background).item():.6f}")
    
    # Initialize GradientExplainer with the model only
    explainer = shap.GradientExplainer(
        wrapped_model,
        background,
        batch_size=1
    )
    
    # Log model output before SHAP computation
    with torch.no_grad():
        model_output = wrapped_model(input_values)
        logger.info(f"Model output shape: {model_output.shape}")
        logger.info(f"Model output mean: {torch.mean(model_output).item():.6f}")
        logger.info(f"Model output std: {torch.std(model_output).item():.6f}")
        logger.info(f"Model output sum: {torch.sum(model_output).item():.6f}")
    
    # retrieve logits
    logits = model(input_values).logits

    # take argmax and decode
    predicted_ids = torch.argmax(logits, dim=-1)
    logger.info(f"Predicted IDs: {predicted_ids}")
    transcription = processor.batch_decode(predicted_ids)
    logger.info(f"Transcription: {transcription}")

    output_string = ''.join([list(vocab.keys())[list(vocab.values()).index(id.item())] for id in predicted_ids[0]])
    logger.info(f"Decoded output string: {output_string}")

    # Get SHAP values
    logger.info("Computing SHAP values with GradientExplainer")
    shap_values = explainer.shap_values(input_values)
    logger.info(f"Raw SHAP values type: {type(shap_values)}")
    
    # Convert to numpy and process
    # if isinstance(shap_values[0], torch.Tensor):
    #     shap_values = [v.cpu().numpy() for v in shap_values]
    
    # Convert to numpy array and handle shapes
    # shap_values = np.array(shap_values)  # Shape: (1, batch, seq_len)
    logger.info(f"SHAP values shape after conversion: {shap_values.shape}")
    
    return shap_values

In [6]:
def compute_metrics(processor, device, wrapped_model, model, vocab, test_set: List[Dict]) -> Dict:
    """Compute evaluation metrics for the test set"""
    logger.info("Computing metrics for test set")
    metrics = {
        "shap_noise_correlation": [],
        "shap_confidence_correlation": [],
        "wer_correlation": []
    }
    
    # Store SHAP values and other computed values for visualization
    visualization_data = []
    
    for i, sample in enumerate(tqdm(test_set, desc="Computing metrics")):
        logger.info(f"Processing sample {i+1}/{len(test_set)}")
        audio = sample["audio"]
        text = sample["text"]
        # Get model prediction and confidence
        inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
        input_values = inputs.input_values.to(device)
        
        with torch.no_grad():
            logits = model(input_values).logits
            probs = torch.softmax(logits, dim=-1)
            confidence = torch.mean(torch.max(probs, dim=-1)[0]).item()
        logger.info(f"Model confidence: {confidence:.4f}")
        
        # Compute SHAP values
        shap_values = compute_shap_values(processor, device, wrapped_model, model, vocab, audio)
        logger.info(f"SHAP values shape: {shap_values.shape}")
        logger.info(f"SHAP values range: [{np.min(shap_values):.4f}, {np.max(shap_values):.4f}]")

        np.save(f"data/shap_values_sample_{i+1+8}_{sample['type']}_{sample['snr']}",shap_values)
        np.save(f"data/audio_sample_{i+1+8}_{sample['type']}_{sample['snr']}",sample["audio"])
        np.save(f"data/noise_sample_{i+1+8}_{sample['type']}_{sample['snr']}",sample["noise"])
        np.save(f"data/text_sample_{i+1+8}_{sample['type']}_{sample['snr']}.npy",text)

In [7]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

# Load model and processor
logger.info(f"Loading model: {model_name}")
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
model = model.to(device)
vocab = {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "|": 4, "E": 5, "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, "'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}
logger.info("Model loaded successfully")

# Create model wrapper for SHAP
wrapped_model = _create_model_wrapper(model)
logger.info("Model wrapper created")

2025-09-15 22:38:07,342 - INFO - Using device: cuda
2025-09-15 22:38:07,343 - INFO - Loading model: facebook/wav2vec2-base-960h


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


2025-09-15 22:38:11,762 - INFO - Model loaded successfully
2025-09-15 22:38:11,763 - INFO - Model wrapper created


In [8]:
# Initialize evaluator


# Create test set
logger.info("Creating test set...")
test_set = create_test_set(num_samples=1)
logger.info(test_set)

2025-09-15 22:38:14,162 - INFO - Creating test set...
2025-09-15 22:38:14,163 - INFO - Creating test set with 1 samples


Creating test samples:   0%|                                                                      | 0/1 [00:00<?, ?it/s]

2025-09-15 22:38:23,690 - INFO - Added clean sample 1




2025-09-15 22:38:23,698 - INFO - Added noisy sample 1 with SNR 5dB
2025-09-15 22:38:23,704 - INFO - Added noisy sample 1 with SNR 2dB
2025-09-15 22:38:23,710 - INFO - Added noisy sample 1 with SNR 1dB


Creating test samples: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.20s/it]

2025-09-15 22:38:23,713 - INFO - Test set created with 4 total samples
2025-09-15 22:38:23,714 - INFO - [{'type': 'clean', 'audio': array([-0.00039673, -0.00064087,  0.00296021, ...,  0.00305176,
       -0.00045776, -0.00085449]), 'text': 'I HAVE REMAINED A PRISONER ONLY BECAUSE I WISHED TO BE ONE AND WITH THIS HE STEPPED FORWARD AND BURST THE STOUT CHAINS AS EASILY AS IF THEY HAD BEEN THREADS', 'snr': inf, 'noise': array([0., 0., 0., ..., 0., 0., 0.])}, {'type': 'noisy', 'audio': array([0.02037643, 0.010542  , 0.00720467, ..., 0.00316433, 0.03490779,
       0.0047002 ]), 'text': 'I HAVE REMAINED A PRISONER ONLY BECAUSE I WISHED TO BE ONE AND WITH THIS HE STEPPED FORWARD AND BURST THE STOUT CHAINS AS EASILY AS IF THEY HAD BEEN THREADS', 'snr': 5, 'noise': array([0.02077316, 0.01118287, 0.00424446, ..., 0.00011257, 0.03536556,
       0.0055547 ])}, {'type': 'noisy', 'audio': array([ 0.01691775,  0.03277654, -0.09975552, ..., -0.04221251,
        0.03097612,  0.00542498]), 'text': 'I HAV




In [19]:
print(len(test_set))
sample = test_set[0]
np.save(f"data/audio_sample_{0}_{sample['type']}_{sample['snr']}",sample["audio"])
audio = np.load("data/audio_sample_0_clean_inf.npy")
sf.write("data/test_audio.wav",audio, 16000)

44


In [9]:
# Compute metrics and get visualization data
logger.info("Computing metrics...")
compute_metrics(processor, device, wrapped_model, model, vocab, test_set)

2025-09-15 22:38:30,842 - INFO - Computing metrics...
2025-09-15 22:38:30,843 - INFO - Computing metrics for test set


Computing metrics:   0%|                                                                          | 0/4 [00:00<?, ?it/s]

2025-09-15 22:38:30,845 - INFO - Processing sample 1/4
2025-09-15 22:38:31,338 - INFO - Model confidence: 0.9759
2025-09-15 22:38:31,339 - INFO - Computing SHAP values
2025-09-15 22:38:31,371 - INFO - Created background samples with shape torch.Size([5, 183600])
2025-09-15 22:38:31,372 - INFO - Background mean: -0.000007
2025-09-15 22:38:31,373 - INFO - Background std: 0.010000
2025-09-15 22:38:31,748 - INFO - Model output shape: torch.Size([1, 573])
2025-09-15 22:38:31,749 - INFO - Model output mean: 14.526364
2025-09-15 22:38:31,750 - INFO - Model output std: 1.641433
2025-09-15 22:38:31,766 - INFO - Model output sum: 8323.606445
2025-09-15 22:38:31,802 - INFO - Predicted IDs: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0, 10,  0,  0,  4,  4, 11, 11,  0,  0,  7, 25,  0,  5,  4,
          4,  4, 13, 13,  0,  0,  5,  0,  0, 17,  0,  0,  0,  7,  7, 10,  0,  9,
          9,  0,  5,  5, 14, 14,  4,  4,  4,  0,  0,  7,  0,  4,  4, 

Computing metrics:  25%|███████████████                                             | 1/4 [1:33:06<4:39:18, 5586.16s/it]

2025-09-16 00:11:37,006 - INFO - Processing sample 2/4
2025-09-16 00:11:37,043 - INFO - Model confidence: 0.8287
2025-09-16 00:11:37,045 - INFO - Computing SHAP values
2025-09-16 00:11:37,047 - INFO - Created background samples with shape torch.Size([5, 183600])
2025-09-16 00:11:37,048 - INFO - Background mean: -0.000002
2025-09-16 00:11:37,050 - INFO - Background std: 0.009996
2025-09-16 00:11:37,318 - INFO - Model output shape: torch.Size([1, 573])
2025-09-16 00:11:37,319 - INFO - Model output mean: 8.384657
2025-09-16 00:11:37,320 - INFO - Model output std: 3.015109
2025-09-16 00:11:37,321 - INFO - Model output sum: 4804.408203
2025-09-16 00:11:37,368 - INFO - Predicted IDs: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0, 10,  0,  0,  4,  0, 11, 11,  0,  7,  0, 25,  5,  0,  4,
          4,  0, 13,  0,  0,  0,  5,  0,  0,  0, 17,  0,  0,  7, 10,  0,  0,  0,
          9,  0,  5,  5, 14,  4,  4,  0,  0,  0,  7,  6,  0,  4,  4,  

Computing metrics:  50%|██████████████████████████████                              | 2/4 [4:01:39<4:11:26, 7543.46s/it]

2025-09-16 02:40:10,584 - INFO - Processing sample 3/4
2025-09-16 02:40:10,617 - INFO - Model confidence: 0.8386
2025-09-16 02:40:10,618 - INFO - Computing SHAP values
2025-09-16 02:40:10,620 - INFO - Created background samples with shape torch.Size([5, 183600])
2025-09-16 02:40:10,621 - INFO - Background mean: -0.000002
2025-09-16 02:40:10,622 - INFO - Background std: 0.009996
2025-09-16 02:40:10,789 - INFO - Model output shape: torch.Size([1, 573])
2025-09-16 02:40:10,790 - INFO - Model output mean: 7.901667
2025-09-16 02:40:10,791 - INFO - Model output std: 2.484672
2025-09-16 02:40:10,792 - INFO - Model output sum: 4527.655273
2025-09-16 02:40:10,841 - INFO - Predicted IDs: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0, 10,  0,  0,  4,  4,  0, 11,  0,  7,  0, 25,  5,  0,  4,
          4,  0, 13,  0,  0,  5,  0,  0,  0, 17,  0,  0,  7,  0,  0,  0,  0,  0,
          9,  0,  0,  0,  0,  4,  4,  0,  0,  0,  5,  0,  0,  4,  0,  

Computing metrics:  75%|█████████████████████████████████████████████               | 3/4 [6:32:14<2:17:04, 8224.39s/it]

2025-09-16 05:10:45,282 - INFO - Processing sample 4/4
2025-09-16 05:10:45,316 - INFO - Model confidence: 0.8459
2025-09-16 05:10:45,317 - INFO - Computing SHAP values
2025-09-16 05:10:45,319 - INFO - Created background samples with shape torch.Size([5, 183600])
2025-09-16 05:10:45,320 - INFO - Background mean: -0.000004
2025-09-16 05:10:45,322 - INFO - Background std: 0.010008
2025-09-16 05:10:45,524 - INFO - Model output shape: torch.Size([1, 573])
2025-09-16 05:10:45,525 - INFO - Model output mean: 7.786644
2025-09-16 05:10:45,526 - INFO - Model output std: 2.288420
2025-09-16 05:10:45,527 - INFO - Model output sum: 4461.746582
2025-09-16 05:10:45,587 - INFO - Predicted IDs: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0, 10,  0,  0,  4,  4,  0, 11,  0,  7, 25,  0,  0,  4,  4,
          0,  6, 11,  0,  7, 25,  0,  0,  4,  0,  0,  0,  0,  0,  0,  5,  0,  0,
          0,  0,  0,  0,  0,  4,  0,  0,  0,  0,  5,  0,  0,  4,  0,  

Computing metrics: 100%|██████████████████████████████████████████████████████████████| 4/4 [9:03:37<00:00, 8154.29s/it]
