In [None]:
import os
import IPython.display as ipd
from IPython.display import Audio, display, HTML

import soundfile as sf
import matplotlib.pyplot as plt
import pandas as pd
import torch

from speechbrain.augment.time_domain import DropChunk, DropFreq, AddReverb, AddNoise, DropBitResolution, SpeedPerturb
from speechbrain.processing.features import InputNormalization
from speechbrain.lobes.features import Fbank

In [None]:
# Seed torch's global RNG so that "Restart Kernel -> Run All" starts from the same
# random state every time (the augmentations further down are stochastic and
# presumably draw from this generator -- TODO confirm against speechbrain).
SEED = 0
torch.manual_seed(SEED)

# All 'data/...' paths in this notebook are relative to the repository root.
# Set the ADV_SR_ROOT environment variable to point at your own checkout; the
# fallback is the original author's location.
PROJECT_ROOT = os.environ.get('ADV_SR_ROOT', '/home/ahmad/adversarial-robustness-for-sr')
os.chdir(PROJECT_ROOT)

## Dataloading

In [None]:
# Two utterances stored side by side in the same dataset directory.
path1 = 'data/voxceleb/voxceleb1_2/id00012/21Uxsk56VDQ/00001.wav'
path2 = 'data/voxceleb/voxceleb1_2/id00012/21Uxsk56VDQ/00002.wav'

(audio1, sr1), (audio2, sr2) = sf.read(path1), sf.read(path2)

# The clips are batched together below, so they must share one sampling rate.
assert sr1 == sr2
sr = sr1

# Trim both clips to the shorter one's length (slicing the shorter clip is a no-op).
min_len = min(len(audio1), len(audio2))
audio1, audio2 = audio1[:min_len], audio2[:min_len]

# Convert to float32 tensors with a leading batch axis, then concatenate along
# that axis so `audio` holds both clips (row 0 = path1, row 1 = path2).
audio1, audio2 = (torch.tensor(clip).float().unsqueeze(0) for clip in (audio1, audio2))
audio = torch.cat((audio1, audio2), dim=0)

## Utils Definition

In [None]:
def normalize_audio(waveform: torch.Tensor, method: str = "rms", target_level: float = -20.0,
                    clip_limit: float = 0.999, epsilon: float = 1e-8) -> torch.Tensor:
    """
    Rescale a single-channel waveform with one of several level-normalization strategies.

    The signal is first centered (its mean is subtracted), then scaled according
    to ``method``, and finally hard-clipped to ``[-clip_limit, clip_limit]``
    whenever ``clip_limit`` is below 1.0.

    Args:
        waveform: Input waveform tensor of shape (1, samples).
        method: One of
            'peak'       -- divide by the largest absolute sample;
            'rms'        -- scale so the RMS equals ``10 ** (target_level / 20)``;
            'percentile' -- divide by the absolute sample found at sorted
                            position ``int(0.995 * samples)``;
            'dynamic'    -- sign-preserving ``log1p`` compression followed by
                            peak normalization.
        target_level: Target RMS level in dB; only used by the 'rms' method.
        clip_limit: Absolute clipping bound; clipping is skipped when it is >= 1.0.
        epsilon: Small constant added to denominators to avoid division by zero.

    Returns:
        Normalized waveform tensor of shape (1, samples).

    Raises:
        AssertionError: If the waveform is not 1-D after dropping the leading axis.
        ValueError: If ``method`` is not one of the supported names.
    """
    signal = waveform.squeeze(0)
    assert signal.dim() == 1, "Expected single-channel audio"

    # Remove the DC offset so level measurements are taken around zero.
    signal = signal - signal.mean()

    if method == "peak":
        # Scale by the largest magnitude in the clip.
        signal = signal / (signal.abs().max() + epsilon)

    elif method == "rms":
        # Energy-based scaling: convert the dB target to a linear RMS value.
        current_rms = signal.pow(2).mean().sqrt()
        desired_rms = 10 ** (target_level / 20)
        scale = desired_rms / (current_rms + epsilon)
        signal = signal * scale

    elif method == "percentile":
        # Use a high-percentile magnitude as reference so isolated spikes
        # do not dictate the gain.
        magnitudes = signal.abs().sort().values
        last = len(magnitudes) - 1
        position = min(int(len(magnitudes) * 0.995), last)
        signal = signal / (magnitudes[position] + epsilon)

    elif method == "dynamic":
        # Logarithmic compression of the magnitude, keeping each sample's sign,
        # followed by peak normalization of the compressed signal.
        magnitude = signal.abs() + epsilon
        squashed = torch.sign(signal) * torch.log1p(magnitude) / torch.log1p(torch.tensor(1.0))
        signal = squashed / (squashed.abs().max() + epsilon)

    else:
        raise ValueError(f"Unsupported normalization method: {method}")

    # Hard-limit the result unless clipping has been disabled (clip_limit >= 1.0).
    if clip_limit < 1.0:
        signal = signal.clamp(min=-clip_limit, max=clip_limit)

    return signal.unsqueeze(0)


def audio_player_html(audio_array, sr):
    """
    Return the HTML markup of an IPython audio player for a single clip.

    Args:
        audio_array: Audio samples, in any form IPython's ``Audio`` accepts as
            raw data. Torch tensors are detached and copied to host memory
            first, so tensors that live on a GPU or track gradients work too.
        sr: Sampling rate of ``audio_array`` in Hz.

    Returns:
        The string produced by ``Audio(...)._repr_html_()``.
    """
    if isinstance(audio_array, torch.Tensor):
        # Audio converts raw samples through NumPy, which cannot read CUDA or
        # grad-tracking tensors directly.
        audio_array = audio_array.detach().cpu().numpy()
    return Audio(audio_array, rate=sr)._repr_html_()


## Waveform Normalization

In [None]:
# Get normalized waveforms
norm_method = 'rms'
audio1_norm = normalize_audio(audio1, norm_method)
audio2_norm = normalize_audio(audio2, norm_method)
audio_norm = torch.cat((audio1_norm, audio2_norm), dim=0)

# One dictionary entry per table column; each list holds one item per sample (row).
audio_data = {
    # Path of each clip relative to the 'voxceleb1_2' directory.
    'Sample': [p[p.index('voxceleb1_2') + len('voxceleb1_2') + 1:] for p in (path1, path2)],
    'Raw Audio': [audio[0], audio[1]],
    f'{norm_method.upper()} Normalized Audio': [audio_norm[0], audio_norm[1]],
}

# Build the HTML table dynamically
headers = list(audio_data.keys())

# Determine the number of samples (assuming all lists are the same length)
n_samples = len(audio_data['Sample'])

# Header row with centered headings
header_row = '<tr>' + ''.join(
    f'<th style="padding: 8px; text-align: center;">{header}</th>' for header in headers
) + '</tr>'

# One table row per sample
body_rows = []
for i in range(n_samples):
    cells = []
    for key in headers:
        entry = audio_data[key][i]
        if key != 'Sample':
            # IPython's Audio rescales every clip to full scale by default, which
            # would make the raw and the normalized players sound equally loud and
            # defeat the comparison. normalize=False keeps the actual levels; it
            # requires samples within [-1, 1], which holds for the normalized clips
            # (normalize_audio clips to +/-0.999 by default) and is expected for
            # the raw clips as read by soundfile.
            entry = Audio(entry.detach().cpu().numpy(), rate=sr, normalize=False)._repr_html_()
        cells.append(f'<td style="padding: 8px;">{entry}</td>')
    body_rows.append('<tr>' + ''.join(cells) + '</tr>')

table_html = (
    '<table border="1" style="border-collapse: collapse; text-align: center;">'
    + header_row
    + ''.join(body_rows)
    + '</table>'
)

# Display the complete table in the Jupyter Notebook
display(HTML(table_html))

## Features Normalization

In [None]:
# Extract filterbank features from the raw and from the level-normalized batch
audio_processor = Fbank(sample_rate=16000, deltas=False, n_mels=80, f_min=0, f_max=None,
                        n_fft=512, win_length=25, hop_length=10, left_frames=0, right_frames=0)
audio_feats = audio_processor(audio)
audio_norm_feats = audio_processor(audio_norm)

# Features normalization (per sentence, mean only since std_norm=False).
# NOTE: despite its name, `mfcc_norm` is applied to the Fbank features above.
mfcc_norm = InputNormalization(norm_type='sentence', std_norm=False)

# speechbrain expects *relative* lengths (fraction of the padded length) as
# floats; 1.0 marks every clip as full length. Sized from the batch instead of
# being hard-coded to two clips.
rel_lengths = torch.ones(audio.shape[0])
audio_feats_norm = mfcc_norm(audio_feats, lengths=rel_lengths)
audio_norm_feats_norm = mfcc_norm(audio_norm_feats, lengths=rel_lengths)

# Plot features
T_max = 320   # number of leading frames shown in every plot
clip_idx = 1  # which clip of the batch to visualize (the second one)

feature_panels = [
    ('Fbank - raw audio', audio_feats),
    ('Fbank - level-normalized audio', audio_norm_feats),
    ('Mean-normalized Fbank - raw audio', audio_feats_norm),
    ('Mean-normalized Fbank - level-normalized audio', audio_norm_feats_norm),
    ('|Difference| between mean-normalized Fbanks', abs(audio_feats_norm - audio_norm_feats_norm)),
    ('|Difference| between un-normalized Fbanks', abs(audio_norm_feats - audio_feats)),
]

for title, feats in feature_panels:
    fig, ax = plt.subplots()
    # Transpose so the first feature axis runs along x; axis labels assume the
    # speechbrain (batch, time, mel) layout -- TODO confirm.
    image = ax.imshow(feats[clip_idx, :T_max, ...].T, origin='lower')
    ax.set(title=title, xlabel='Frame', ylabel='Mel bin')
    fig.colorbar(image, ax=ax)
    plt.show()

## DropChunk

In [None]:
# Chunk-dropping augmentation with dropped-chunk lengths drawn between
# drop_length_low and drop_length_high (presumably in samples -- TODO confirm).
chunk_dropper = DropChunk(drop_length_low=2000, drop_length_high=8000)

# The second argument holds *relative* lengths (fraction of the padded length),
# passed as floats; 1.0 marks every clip as full length. Sized from the batch
# rather than hard-coded to two clips.
chunk_dropped = chunk_dropper(audio, torch.ones(audio.shape[0]))

## DropFreq

In [None]:
# Frequency-dropping augmentation; the two arguments presumably bound how many
# frequency bands are removed per call (between 1 and 3) -- TODO confirm.
freq_dropper = DropFreq(drop_freq_count_low=1, drop_freq_count_high=3)
# Apply it to the whole batch of raw clips; the result feeds the comparison table below.
freq_dropped = freq_dropper(audio)

## AddReverb

In [None]:
# Reverberation augmentation configured from a CSV file (presumably a listing of
# room impulse responses -- verify against the file). The path is relative to
# the working directory set near the top of the notebook.
reverb = AddReverb('data/voxceleb/RIRS_NOISES/reverb.csv')
# Apply it to the whole batch of raw clips; the result feeds the comparison table below.
reverbed = reverb(audio)

## AddNoise

In [None]:
# Additive-noise augmentation configured from a CSV file (presumably a listing
# of noise recordings -- verify against the file). The path is relative to the
# working directory set near the top of the notebook.
noisifier = AddNoise('data/voxceleb/RIRS_NOISES/noise.csv')

# `lengths` holds *relative* lengths (fraction of the padded length), passed as
# floats; 1.0 marks every clip as full length. Sized from the batch rather than
# hard-coded to two clips.
noisy = noisifier(audio, lengths=torch.ones(audio.shape[0]))

## DropBitResolution

In [None]:
# Bit-resolution augmentation with the library's default settings.
bit_dropper = DropBitResolution()
# Apply it to the whole batch of raw clips; the result feeds the comparison table below.
bit_dropped = bit_dropper(audio)

## SpeedPerturb

In [None]:
# Use the GPU only when one is actually present so the notebook also runs on
# CPU-only machines.
speed_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Speed perturbation choosing between the listed speeds (given as percentages
# of the original speed -- TODO confirm against speechbrain).
speed_perturber = SpeedPerturb(orig_freq=sr, speeds=[90, 110], device=speed_device)

# Bring the result back to host memory: the audio players further down convert
# their input through NumPy, which cannot read CUDA tensors.
speeded = speed_perturber(audio).cpu()

In [None]:
# Collect the clean and augmented waveforms to compare: one dictionary entry per
# table column, each list holding one clip per sample (row).
audio_data = {
    # Path of each clip relative to the 'voxceleb1_2' directory.
    'Sample': [p[p.index('voxceleb1_2') + len('voxceleb1_2') + 1:] for p in (path1, path2)],
    'Clean': [audio[0], audio[1]],
    'Noisy': [noisy[0], noisy[1]],
    'Reverberated': [reverbed[0], reverbed[1]],
    'Frequency Dropped': [freq_dropped[0], freq_dropped[1]],
    'Chunks Dropped': [chunk_dropped[0], chunk_dropped[1]],
    'Bit Dropped': [bit_dropped[0], bit_dropped[1]],
    'Speed Perturbed': [speeded[0], speeded[1]],  # header spelling fixed (was 'Pertubed')
}

# Build the HTML table dynamically
headers = list(audio_data.keys())

# Determine the number of samples (assuming all lists are the same length)
n_samples = len(audio_data['Sample'])

# Header row with centered headings
header_row = '<tr>' + ''.join(
    f'<th style="padding: 8px; text-align: center;">{header}</th>' for header in headers
) + '</tr>'

# One table row per sample
body_rows = []
for i in range(n_samples):
    cells = []
    for key in headers:
        entry = audio_data[key][i]
        if key != 'Sample':
            # Hand the player a detached host-memory tensor; some augmentations
            # may have produced their output on another device.
            entry = audio_player_html(entry.detach().cpu(), sr)
        cells.append(f'<td style="padding: 8px;">{entry}</td>')
    body_rows.append('<tr>' + ''.join(cells) + '</tr>')

table_html = (
    '<table border="1" style="border-collapse: collapse; text-align: center;">'
    + header_row
    + ''.join(body_rows)
    + '</table>'
)

# Display the complete table in the Jupyter Notebook
display(HTML(table_html))