In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import shap
import logging
import sys
import os
import soundfile as sf
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datasets import load_dataset
import librosa
import librosa.display
from scipy.stats import pearsonr
from typing import List, Dict, Tuple
import json
from pathlib import Path
from tqdm import tqdm
import warnings

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('evaluation.log')
    ]
)
logger = logging.getLogger(__name__)

model_name = "facebook/wav2vec2-base-960h"
            

  from .autonotebook import tqdm as notebook_tqdm
2025-09-15 13:03:36.973818: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-15 13:03:36.983051: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757934216.992809   29746 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757934216.995646   29746 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1757934217.004638   29746 computation_placer.cc:177] computation placer already r

In [2]:
def _create_model_wrapper(model):
    """Create a wrapper class for the model to get properly shaped output"""
    class ModelWrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model
        
        def forward(self, x):
            # Ensure input has correct shape for the model
            if len(x.shape) == 4:
                x = x.squeeze(1).squeeze(1)  # Remove extra dimensions
            elif len(x.shape) == 3:
                x = x.squeeze(1)  # Remove extra dimension
            
            # Add attention mask
            attention_mask = torch.ones_like(x)
            
            # Forward pass with attention mask
            logits = self.model(x, attention_mask=attention_mask).logits
            
            # Log the shape and statistics of logits
            logger.debug(f"Logits shape: {logits.shape}")
            logger.debug(f"Logits mean: {torch.mean(logits).item():.6f}")
            logger.debug(f"Logits std: {torch.std(logits).item():.6f}")
            
            # For SHAP, aggregate over vocab to get a scalar per time step
            return torch.max(logits,dim=-1).values  # [batch, seq_len]
    
    return ModelWrapper(model)

In [3]:
def _add_noise(audio: np.ndarray, snr_db: float) -> np.ndarray:
    """Add white noise to audio at specified SNR"""
    signal_power = np.mean(audio ** 2)
    noise_power = signal_power / (10 ** (snr_db / 10))
    noise = np.random.normal(0, np.sqrt(noise_power), len(audio))
    return audio + noise

In [None]:
def create_test_set(num_samples: int = 10) -> Dict:
    """Create a controlled test set with various conditions"""
    logger.info(f"Creating test set with {num_samples} samples")
    ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
    test_set = []
    dataset_index = 15

    for i in tqdm(range(min(num_samples, len(ds))), desc="Creating test samples"):
        sample = ds[i+dataset_index]
        audio = sample["audio"]["array"]
        while len(audio) < 100000:
            dataset_index += 1
            sample = ds[i+dataset_index]
            audio = sample["audio"]["array"]
        text = sample["text"]
        
        # Create clean sample
        test_set.append({
            "type": "clean",
            "audio": audio,
            "text": text,
            "snr": float('inf'),
            "noise": np.zeros_like(audio)
        })
        logger.info(f"Added clean sample {i+1}")
        
        # Create noisy samples with different SNRs [20, 10, 0, -5]
        for snr in tqdm([5,2,1], desc=f"Adding noise to sample {i+1}", leave=False):
            noisy_audio = _add_noise(audio, snr)
            test_set.append({
                "type": "noisy",
                "audio": noisy_audio,
                "text": text,
                "snr": snr,
                "noise": noisy_audio - audio
            })
            logger.info(f"Added noisy sample {i+1} with SNR {snr}dB")

    logger.info(f"Test set created with {len(test_set)} total samples")
    return test_set

In [5]:
def compute_shap_values(processor, device, wrapped_model, model, vocab, audio: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute SHAP values for an audio sample using GradientExplainer"""
    logger.info("Computing SHAP values")
    # Process audio
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    input_values = inputs.input_values.to(device)
    
    # Ensure input_values has the correct shape [batch_size, sequence_length]
    if len(input_values.shape) == 3:
        input_values = input_values.squeeze(1)
    
    # Create background samples with correct shape [batch_size, sequence_length]
    num_background = 5
    background = torch.zeros((num_background, input_values.shape[1]), device=device)
    background += torch.randn_like(background) * 0.01  # Add small random noise
    logger.info(f"Created background samples with shape {background.shape}")
    logger.info(f"Background mean: {torch.mean(background).item():.6f}")
    logger.info(f"Background std: {torch.std(background).item():.6f}")
    
    # Initialize GradientExplainer with the model only
    explainer = shap.GradientExplainer(
        wrapped_model,
        background,
        batch_size=1
    )
    
    # Log model output before SHAP computation
    with torch.no_grad():
        model_output = wrapped_model(input_values)
        logger.info(f"Model output shape: {model_output.shape}")
        logger.info(f"Model output mean: {torch.mean(model_output).item():.6f}")
        logger.info(f"Model output std: {torch.std(model_output).item():.6f}")
        logger.info(f"Model output sum: {torch.sum(model_output).item():.6f}")
    
    # retrieve logits
    logits = model(input_values).logits

    # take argmax and decode
    predicted_ids = torch.argmax(logits, dim=-1)
    logger.info(f"Predicted IDs: {predicted_ids}")
    transcription = processor.batch_decode(predicted_ids)
    logger.info(f"Transcription: {transcription}")

    output_string = ''.join([list(vocab.keys())[list(vocab.values()).index(id.item())] for id in predicted_ids[0]])
    logger.info(f"Decoded output string: {output_string}")

    # Get SHAP values
    logger.info("Computing SHAP values with GradientExplainer")
    shap_values = explainer.shap_values(input_values)
    logger.info(f"Raw SHAP values type: {type(shap_values)}")
    
    # Convert to numpy and process
    # if isinstance(shap_values[0], torch.Tensor):
    #     shap_values = [v.cpu().numpy() for v in shap_values]
    
    # Convert to numpy array and handle shapes
    # shap_values = np.array(shap_values)  # Shape: (1, batch, seq_len)
    logger.info(f"SHAP values shape after conversion: {shap_values.shape}")
    
    return shap_values

In [6]:
def compute_metrics(processor, device, wrapped_model, model, vocab, test_set: List[Dict]) -> Dict:
    """Compute evaluation metrics for the test set"""
    logger.info("Computing metrics for test set")
    metrics = {
        "shap_noise_correlation": [],
        "shap_confidence_correlation": [],
        "wer_correlation": []
    }
    
    # Store SHAP values and other computed values for visualization
    visualization_data = []
    
    for i, sample in enumerate(tqdm(test_set, desc="Computing metrics")):
        logger.info(f"Processing sample {i+1}/{len(test_set)}")
        audio = sample["audio"]
        text = sample["text"]
        # Get model prediction and confidence
        inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
        input_values = inputs.input_values.to(device)
        
        with torch.no_grad():
            logits = model(input_values).logits
            probs = torch.softmax(logits, dim=-1)
            confidence = torch.mean(torch.max(probs, dim=-1)[0]).item()
        logger.info(f"Model confidence: {confidence:.4f}")
        
        # Compute SHAP values
        shap_values = compute_shap_values(processor, device, wrapped_model, model, vocab, audio)
        logger.info(f"SHAP values shape: {shap_values.shape}")
        logger.info(f"SHAP values range: [{np.min(shap_values):.4f}, {np.max(shap_values):.4f}]")

        np.save(f"data/shap_values_sample_{i+1}_{sample['type']}_{sample['snr']}",shap_values)
        np.save(f"data/audio_sample_{i+1}_{sample['type']}_{sample['snr']}",sample["audio"])
        np.save(f"data/noise_sample_{i+1}_{sample['type']}_{sample['snr']}",sample["noise"])
        np.save(f"data/text_sample_{i+1}_{sample['type']}_{sample['snr']}.npy",text)

In [7]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

# Load model and processor
logger.info(f"Loading model: {model_name}")
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
model = model.to(device)
vocab = {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "|": 4, "E": 5, "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, "'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}
logger.info("Model loaded successfully")

# Create model wrapper for SHAP
wrapped_model = _create_model_wrapper(model)
logger.info("Model wrapper created")

2025-09-15 13:03:39,077 - INFO - Using device: cuda
2025-09-15 13:03:39,078 - INFO - Loading model: facebook/wav2vec2-base-960h


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


2025-09-15 13:03:40,785 - INFO - Model loaded successfully
2025-09-15 13:03:40,786 - INFO - Model wrapper created


In [8]:
# Initialize evaluator


# Create test set
logger.info("Creating test set...")
test_set = create_test_set(num_samples=3)
logger.info(test_set)

2025-09-15 13:03:40,792 - INFO - Creating test set...
2025-09-15 13:03:40,792 - INFO - Creating test set with 3 samples


Creating test samples: 100%|██████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.34it/s]

2025-09-15 13:03:42,953 - INFO - Test set created with 0 total samples
2025-09-15 13:03:42,954 - INFO - []





In [19]:
print(len(test_set))
sample = test_set[0]
np.save(f"data/audio_sample_{0}_{sample['type']}_{sample['snr']}",sample["audio"])
audio = np.load("data/audio_sample_0_clean_inf.npy")
sf.write("data/test_audio.wav",audio, 16000)

44


In [9]:
# Compute metrics and get visualization data
logger.info("Computing metrics...")
compute_metrics(processor, device, wrapped_model, model, vocab, test_set)

2025-09-15 13:03:46,062 - INFO - Computing metrics...
2025-09-15 13:03:46,063 - INFO - Computing metrics for test set


Computing metrics: 0it [00:00, ?it/s]
