In [1]:
import torch
import numpy as np
import logging
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import librosa
import matplotlib.pyplot as plt
import soundfile as sf
from datasets import load_dataset
from tqdm import tqdm
import sys
import shap

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('evaluation.log')
    ]
)
logger = logging.getLogger(__name__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_name = "facebook/wav2vec2-base-960h"

  from .autonotebook import tqdm as notebook_tqdm
2025-09-12 22:13:08.539950: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-12 22:13:08.696926: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757707988.759424    3720 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757707988.776820    3720 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1757707988.906050    3720 computation_placer.cc:177] computation placer already r

In [2]:
logger.info(f"Loading model: {model_name}")
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
model = model.to(device)
logger.info("Model loaded successfully")

2025-09-12 22:13:11,616 - INFO - Loading model: facebook/wav2vec2-base-960h


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


2025-09-12 22:13:13,185 - INFO - Model loaded successfully


In [3]:
def plot_spectrograms(S_original, S_amplified, sr):
    """
    Visualizes the original and amplified mel spectrograms side by side.
    """
    # Convert power spectrograms to decibels (dB) for better visualization
    S_db_original = librosa.power_to_db(S_original, ref=np.max)
    S_db_amplified = librosa.power_to_db(S_amplified, ref=np.max)

    # Create figure
    fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(12, 8))
    fig.suptitle('Mel Spectrogram Comparison', fontsize=16)

    # Plot original spectrogram
    img1 = librosa.display.specshow(S_db_original, sr=sr, x_axis='time', y_axis='mel', ax=ax[0])
    ax[0].set(title='Original Spectrogram')
    ax[0].label_outer() # Hide x-axis label for the top plot
    fig.colorbar(img1, ax=ax[0], format='%+2.0f dB')

    # Plot amplified spectrogram
    img2 = librosa.display.specshow(S_db_amplified, sr=sr, x_axis='time', y_axis='mel', ax=ax[1])
    ax[1].set(title='Spectrogram with Amplified Quiet Sections')
    fig.colorbar(img2, ax=ax[1], format='%+2.0f dB')

    # Adjust layout and display the plot
    plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust for suptitle
    plt.show()

In [4]:
def plot_waveforms(y_original, y_modified, sr):
    """
    Visualizes the original and a modified audio waveform side by side.

    Args:
        y_original (np.ndarray): The original audio time series.
        y_modified (np.ndarray): The modified audio time series.
        sr (int): The sampling rate of the audio.
    """
    # Create a time array for the x-axis
    time_original = librosa.times_like(y_original, sr=sr)
    time_modified = librosa.times_like(y_modified, sr=sr)

    # Create a figure with two subplots, sharing the x and y axes
    fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=True, figsize=(12, 8))
    fig.suptitle('Waveform Comparison', fontsize=16)

    # Plot the original waveform
    librosa.display.waveshow(y_original, sr=sr, ax=ax[0], color='b')
    ax[0].set(title='Original Waveform')
    ax[0].set_ylabel('Amplitude')
    ax[0].grid(True, linestyle='--', alpha=0.6)
    ax[0].label_outer() # Hide x-axis label for the top plot

    # Plot the modified waveform
    librosa.display.waveshow(y_modified, sr=sr, ax=ax[1], color='r')
    ax[1].set(title='Modified Waveform')
    ax[1].set_xlabel('Time (s)')
    ax[1].set_ylabel('Amplitude')
    ax[1].grid(True, linestyle='--', alpha=0.6)

    # Add a legend to distinguish the waveforms
    fig.legend(['Original', 'Modified'], loc='upper right')
    
    # Adjust layout to prevent titles from overlapping and display the plot
    plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust for suptitle
    plt.show()

In [5]:
def normalize_and_scale_shap(shap_values, min_val=0.8, default=0.4):
    """
    Normalize SHAP values to [0, 1] and scale to [min_val, max_val].
    """
    # Normalize to [0, 1]
    shap_min = np.min(shap_values)
    shap_max = np.max(shap_values)
    normalized_shap = (shap_values - shap_min) / (shap_max - shap_min + 1e-8)  # Add small epsilon to avoid division by zero

    # Scale to [min_val, max_val]
    scaled_shap = ((normalized_shap-min_val).clip(0,1)/(1 - min_val)).clip(default, 1)
    return scaled_shap

In [9]:
audio = np.load("data/audio_sample_14_clean_inf.npy")
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
input_values = inputs.input_values.to(device)

# retrieve logits
logits = model(input_values).logits

# take argmax and decode
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
logger.info(f"Transcription: {transcription}")
logger.info(f"Ouptut shape: {predicted_ids.shape}")

2025-09-12 22:39:48,829 - INFO - Transcription: ['ON THE GENERAL PRINCIPLES OF ART MISTER QUILTER WRITES WITH EQUAL LUCIDITY']
2025-09-12 22:39:48,829 - INFO - Ouptut shape: torch.Size([1, 281])


In [10]:
shap_values = np.load("data/shap_values_sample_14_clean_inf.npy")
shap_values = shap_values.reshape(audio.shape[0], logits.shape[1])
logger.info(f"Loaded SHAP values with shape: {shap_values.shape}")
logger.info(f"Loaded audio with shape: {audio.shape}")

2025-09-12 22:40:05,808 - INFO - Loaded SHAP values with shape: (90240, 281)
2025-09-12 22:40:05,809 - INFO - Loaded audio with shape: (90240,)


In [None]:
characters = [28,  29,  31,  32,  33,  35,  37,  39,  40,  41,  43,  45,  48,  49,
         50,  53,  57,  58,  59,  61,  62,  63,  66,  67,  68,  69,  70,  71,
         72,  73,  74,  75,  76,  77,  80,  84,  90,  91,  92,  95,  96,  98,
         99, 100, 101, 102, 103, 104, 105, 106, 109, 110, 112, 113, 114, 115,
        116, 117, 118, 119, 120, 121, 122, 123, 126, 127, 128, 129, 130, 131,
        132, 133, 136, 142, 143, 147, 152, 154, 155, 156, 157, 158, 168, 169,
        170, 171, 172, 173, 174, 177, 179, 181, 182, 183, 184, 185, 186, 187,
        189, 195, 199, 201, 202, 203, 205, 207, 208, 209, 210, 212, 215, 217,
        221, 224, 226, 227, 228, 229, 230, 231, 232, 234, 237, 238, 239, 241,
        242, 243, 245, 251, 252, 253, 254, 258, 262, 263, 265, 266, 268, 269,
        270, 271]
character_shap_values = [np.zeros((93680,)) for _ in characters]
for idx, char in enumerate(characters):
    for i, timestep in enumerate(shap_values[1][0]):
        character_shap_values[idx][i] = abs(timestep[char])

window_length_ms = 20

num_of_frames = window_length_ms*16

for arr in character_shap_values:
    for idx in range(0,len(arr),num_of_frames):
        mean = np.mean(arr[idx:min(idx+num_of_frames,len(arr))])
        arr[idx:min(idx+num_of_frames,len(arr))] = mean

character_norm_shap_values = [normalize_and_scale_shap(vals,0.80,0.0) for vals in character_shap_values]

#print(norm_m_shap_values.shape)

character_amplified_audio = [audio * vals for vals in character_norm_shap_values]

amplified_audio = sum(character_amplified_audio)/len(character_amplified_audio) * 10
S_amplified_audio = librosa.feature.melspectrogram(y=amplified_audio, sr=16000, n_fft=2048, hop_length=512)
S_audio = librosa.feature.melspectrogram(y=audio, sr=16000, n_fft=2048, hop_length=512)
sf.write("m_amplified_audio.wav", amplified_audio, 16000)
plot_spectrograms(S_audio, S_amplified_audio, sr=16000)
plot_waveforms(audio, amplified_audio*5, sr=16000)

IndexError: invalid index to scalar variable.