# NLP Project

This notebook is the code associated with our work for the MVA Course "Algorithms for speech and language processing". In this notebook we reimplement and extend the work from "Augmentation Invariant Discrete Representation for
Generative Spoken Language Modeling" by Itai Gat et al.

- Encoder extraction : Adrien Letellier
- Quantizer training & testing : Raphaël Bernas
- Augmentation : Maxime Corlay

In [1]:
!pip install --pre torch torchvision torchaudio
!pip install numpy
!pip install transformers
!pip install datasets
!pip install scikit-learn
!pip install librosa
!pip install soundfile
!pip install python-Levenshtein


Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
import torch
from transformers import Wav2Vec2FeatureExtractor, HubertModel, Wav2Vec2Model, WavLMModel
from datasets import load_dataset
import numpy as np
from sklearn.cluster import KMeans

In [3]:
# C'est un dataset similaire à celui qu'on veut mais plus petit
# Tous les datasets sont sur HuggingFace de toute façon donc c'est facile à changer
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

librispeech_asr_demo.py:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

Generating validation split: 0 examples [00:00, ? examples/s]

## 1. Encoder extraction

### 1.1 HuBERT

#### 1.1.1 Speech Encoder

In [None]:
# Load feature extractor and HuBERT model
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")
model = HubertModel.from_pretrained("facebook/hubert-base-ls960", output_hidden_states=True)
model.eval()

HubertModel(
  (feature_extractor): HubertFeatureEncoder(
    (conv_layers): ModuleList(
      (0): HubertGroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x HubertNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x HubertNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): HubertFeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): HubertEncoder(
    (pos_conv_embed): HubertPositionalConvEmbedding(
      (conv): Para

In [None]:
# Select first point of the dataset
input_values = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")['input_values']

In [None]:
with torch.no_grad():
    outputs = model(input_values)

In [None]:
len(outputs['hidden_states'])

13

In [None]:
outputs['last_hidden_state']

tensor([[[ 0.0924, -0.0873,  0.2480,  ..., -0.0481,  0.1011, -0.3813],
         [ 0.1171, -0.0870,  0.2565,  ..., -0.0525,  0.0991, -0.4402],
         [ 0.1896, -0.0639,  0.2879,  ..., -0.0714,  0.0727, -0.5391],
         ...,
         [ 0.1721,  0.3426,  0.0415,  ..., -0.0303, -0.1977, -0.6863],
         [ 0.1121,  0.1157,  0.1866,  ..., -0.1068, -0.1563, -0.5571],
         [ 0.0897,  0.0344,  0.2302,  ..., -0.0846, -0.0011, -0.4501]]])

In [None]:
outputs['hidden_states'][-1]

tensor([[[ 0.0924, -0.0873,  0.2480,  ..., -0.0481,  0.1011, -0.3813],
         [ 0.1171, -0.0870,  0.2565,  ..., -0.0525,  0.0991, -0.4402],
         [ 0.1896, -0.0639,  0.2879,  ..., -0.0714,  0.0727, -0.5391],
         ...,
         [ 0.1721,  0.3426,  0.0415,  ..., -0.0303, -0.1977, -0.6863],
         [ 0.1121,  0.1157,  0.1866,  ..., -0.1068, -0.1563, -0.5571],
         [ 0.0897,  0.0344,  0.2302,  ..., -0.0846, -0.0011, -0.4501]]])

In [None]:
# Extract encoder output (this includes feature extraction)
encoder_output = outputs['last_hidden_state']  # Shape: (1, seq_len, feature_dim)
print("Encoder Output Shape:", encoder_output.shape)

Encoder Output Shape: torch.Size([1, 292, 768])


In [None]:
encoder_output

tensor([[[ 0.0924, -0.0873,  0.2480,  ..., -0.0481,  0.1011, -0.3813],
         [ 0.1171, -0.0870,  0.2565,  ..., -0.0525,  0.0991, -0.4402],
         [ 0.1896, -0.0639,  0.2879,  ..., -0.0714,  0.0727, -0.5391],
         ...,
         [ 0.1721,  0.3426,  0.0415,  ..., -0.0303, -0.1977, -0.6863],
         [ 0.1121,  0.1157,  0.1866,  ..., -0.1068, -0.1563, -0.5571],
         [ 0.0897,  0.0344,  0.2302,  ..., -0.0846, -0.0011, -0.4501]]])

#### 1.1.2 Quantizer

In [None]:
# Apply K-Means clustering
num_clusters = 50
features = encoder_output.squeeze(0).numpy()
features.shape  # Shape: (seq_len, feature_dim)

(292, 768)

In [None]:
kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10)
kmeans.fit(features)

In [None]:
# Convert encoder output to discrete representations
quantized_ids = kmeans.predict(features)
print("Discrete Representation (First 20 IDs):", quantized_ids[:20])

Discrete Representation (First 20 IDs): [33 33 33 33  9  9  9  9  9  9  9  9  9  4  4  4  4  4  4  4]


### 1.2 wav2vec2

#### 1.2.1 Speech Encoder

In [None]:
# Load feature extractor and model
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
model.eval()



Wav2Vec2Model(
  (feature_extractor): Wav2Vec2FeatureEncoder(
    (conv_layers): ModuleList(
      (0): Wav2Vec2GroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): Wav2Vec2FeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Wav2Vec2Encoder(
    (pos_conv_embed): Wav2Vec2PositionalConvEmbedding(
  

In [None]:
# Select first point of the dataset
input_values = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")['input_values']

In [None]:
with torch.no_grad():
    outputs = model(input_values)

In [None]:
# Extract encoder output (this includes feature extraction)
encoder_output = outputs['last_hidden_state']  # Shape: (1, seq_len, feature_dim)
print("Encoder Output Shape:", encoder_output.shape)

Encoder Output Shape: torch.Size([1, 292, 768])


In [None]:
encoder_output

tensor([[[ 0.0252, -0.0161,  0.1962,  ...,  0.5132,  0.2121, -0.1114],
         [-0.3064, -0.0877,  0.0485,  ...,  0.2346,  0.6384, -0.3538],
         [ 0.2099,  0.1193,  0.5077,  ...,  0.0555,  0.3368,  0.2325],
         ...,
         [-0.3104, -0.0688,  0.0304,  ...,  0.1952,  0.6314, -0.3537],
         [-0.3162, -0.0806,  0.0095,  ...,  0.1865,  0.6372, -0.3541],
         [-0.0199, -0.0527,  0.0903,  ...,  0.3927,  0.2868, -0.3366]]])

#### 1.2.2 Quantizer

In [None]:
# Apply K-Means clustering
num_clusters = 50
features = encoder_output.squeeze(0).numpy()
features.shape  # Shape: (seq_len, feature_dim)

(292, 768)

In [None]:
kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10)
kmeans.fit(features)

In [None]:
# Convert encoder output to discrete representations
quantized_ids = kmeans.predict(features)
print("Discrete Representation (First 20 IDs):", quantized_ids[:20])

Discrete Representation (First 20 IDs): [29  0  1  1  0  1  0  1  1  1  0  0  0  1 25 10 10 10 41  1]


### 1.3 WavLM

#### 1.3.1 Speech Encoder

In [4]:
# Load feature extractor and model
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base")
model = WavLMModel.from_pretrained("microsoft/wavlm-base")
model.eval()

preprocessor_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.24k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/378M [00:00<?, ?B/s]

WavLMModel(
  (feature_extractor): WavLMFeatureEncoder(
    (conv_layers): ModuleList(
      (0): WavLMGroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x WavLMNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x WavLMNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): WavLMFeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): WavLMEncoder(
    (pos_conv_embed): WavLMPositionalConvEmbedding(
      (conv): Parametrized

In [5]:
# Select first point of the dataset
print("Audio Shape:", dataset[0]["audio"]["array"].shape)
input_values = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")['input_values']
print("Input Values Shape:", input_values.shape)

Audio Shape: (93680,)
Input Values Shape: torch.Size([1, 93680])


In [6]:
with torch.no_grad():
    outputs = model(input_values)

In [7]:
# Extract encoder output (this includes feature extraction)
encoder_output = outputs['last_hidden_state']  # Shape: (1, seq_len, feature_dim)
print("Encoder Output Shape:", encoder_output.shape)

Encoder Output Shape: torch.Size([1, 292, 768])


In [8]:
encoder_output

tensor([[[-0.1524, -0.2139, -0.1196,  ...,  1.2128,  0.2217, -0.3977],
         [-0.1470, -0.2864, -0.0996,  ...,  1.2637,  0.2217, -0.4654],
         [-0.1055, -0.3247, -0.1150,  ...,  1.3419,  0.2127, -0.4782],
         ...,
         [ 0.0136, -0.2798, -0.4029,  ...,  0.9122,  0.2058, -0.3439],
         [-0.0423, -0.2395, -0.4088,  ...,  0.9519,  0.1429, -0.4677],
         [-0.1248, -0.2294, -0.2764,  ...,  0.9044,  0.1780, -0.5477]]])

#### 1.3.2 Quantizer

In [9]:
# Apply K-Means clustering
num_clusters = 50
features = encoder_output.squeeze(0).numpy()
features.shape  # Shape: (seq_len, feature_dim)

(292, 768)

In [10]:
kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10)
kmeans.fit(features)

In [11]:
# Convert encoder output to discrete representations
quantized_ids = kmeans.predict(features)
print("Discrete Representation (First 20 IDs):", quantized_ids[:20])
print(quantized_ids.shape)

Discrete Representation (First 20 IDs): [12 12 12 12 12  6  6  6  6  6  6  6 29 29 29 29 29 29  3  3]
(292,)


## 2. Quantizer training & testing

### 2.0 Usefull function

In [12]:
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import os
from Levenshtein import distance as levenshtein_distance

In [None]:
def augment_audio(audio: torch.Tensor, sr: int, augmentation_type: str, **kwargs) -> torch.Tensor:
    """
    Apply an augmentation.

    Parameters:
      audio (torch.Tensor): The input audio waveform (shape: [channels, samples]).
      sr (int): Sampling rate of the audio.
      augmentation_type (str): The type of augmentation to perform.
                                 Options: "gaussian_noise".
      **kwargs: Additional keyword arguments for specific augmentations.

    Returns:
      torch.Tensor: The augmented audio waveform.
    """
    # Gaussian noise :
    if augmentation_type == "gaussian_noise":
        noise_level = kwargs.get("noise_level", 0.005)
        noise = torch.randn_like(audio)
        augmented_audio = audio + noise_level * noise
        return augmented_audio

    else:
        raise ValueError(f"Unknown augmentation type: {augmentation_type}")

In [None]:
def compute_ctc_loss(
    features_: torch.Tensor,
    perturbed_features_: torch.Tensor,
    E0: callable,
    E1: nn.Module,
    blank_token: int = 0,
    target_lengths: torch.Tensor = None,
    input_lengths: torch.Tensor = None,
    target_threshold: float = 0.4
) -> torch.Tensor:
    """
    Compute the CTC loss as described in the paper :
      1) Compute E0(f(x)) to obtain targets tokens.
      2) Compute E1(f(g(x))) to obtain predictions logits.
      3) Compute CTC alignment between target tokens and predictions.

    Args:
      features_ (torch.Tensor): Input features (shape: [batch_size, time_steps, feat_dim]).
      perturbed_features_ (torch.Tensor): Augmented features (shape: [batch_size, time_steps, feat_dim]).
      E0 (nn.Module): Pretrained quantizer -> outputs discrete token IDs.
      E1 (nn.Module): Trainable quantizer (IN PAPER : MLP) -> outputs logits over discrete tokens.
      blank_token (int): Index for the CTC blank token.
      target_threshold (float): The percentage of the input size used for target (bigger than 0.5 would results in problems).

    Returns:
      torch.Tensor: Scalar CTC loss.
    """
    # 1) Get target tokens from original input
    with torch.no_grad():
        target_tokens = E0(features_)  # [batch_size, target_len]

    # 2) Augment input, then get prediction logits
    prediction_logits = E1(perturbed_features_)    # [batch_size, time_steps, vocab_size]

    # Convert logits to log-probabilities for CTC
    prediction_log_probs = F.log_softmax(prediction_logits, dim=-1)  # [batch_size, time_steps, vocab_size]

    # Prepare lengths for CTC
    #    - input_lengths: how many time steps in the prediction output
    #    - target_lengths: how many tokens in the target
    batch_size, time_steps, vocab_size = prediction_log_probs.shape
    target_len = target_tokens.shape[1]

    if input_lengths is None:
        input_lengths = torch.full(
            size=(batch_size,),
            fill_value=time_steps,
            dtype=torch.long,
            device=prediction_log_probs.device
        )
    if target_lengths is None:
        target_lenghts = [int(input_lengths[i] * target_threshold) for i in range(batch_size)]
        target_lengths = torch.tensor(target_lenghts, dtype=torch.long, device=prediction_log_probs.device)

    target_tokens = [target_tokens[i, :target_lengths[i]] for i in range(batch_size)]
    target_tokens = torch.cat(target_tokens, dim=0)

    assert (input_lengths > target_lengths).all(), "CTC usually requires input lengths > target lengths!"

    # 4) CTC expects [time, batch, vocab] for the log probabilities
    prediction_log_probs = prediction_log_probs.permute(1, 0, 2).contiguous()

    # Flatten target tokens for CTC. They must be a 1D tensor concatenated for the batch,
    # but we also need to pass the correct target_lengths for each sample (see doc but target_lenghts size impact a lot the loss).
    target_tokens = target_tokens.view(-1)

    # 5) Instantiate and compute the CTC loss
    ctc_loss_fn = nn.CTCLoss(blank=blank_token, zero_infinity=True)
    loss = ctc_loss_fn(prediction_log_probs, target_tokens, input_lengths, target_lengths)
    return loss

In [13]:
class MLPQuantizer(nn.Module):
    """Trainable quantizer E1 that outputs logits over discrete tokens."""
    def __init__(self, input_dim: int, vocab_size: int):
        super().__init__()
        self.input_dim = input_dim
        self.vocab_size = vocab_size
        self.layer = nn.Linear(input_dim, vocab_size)

    def forward(self, features):
        """
        Args:
            features: Tensor of shape [batch, time_step, feature]
        Returns:
            logits: Tensor of shape [batch, time_step, vocab_size]
        """
        logits = self.layer(features)
        return logits

    def predict(self, features):
        """
        Args:
            features: Tensor of shape [batch, time_step, feature]
        Returns:
            predictions: Tensor of shape [batch, time_step]
        """
        logits = self(features)
        predictions = logits.argmax(dim=-1)
        return predictions

In [14]:
def Quantizer0(features_: torch.Tensor, num_clusters: int, n_init: int =10) -> torch.Tensor:
    """
    Discretize features using a pretrained KMeans model.

    Args:
        features_ (torch.Tensor): Input features (shape: [batch_size, time_steps, feat_dim]).
        kmeans (KMeans): Pretrained KMeans model.

    Returns:
        torch.Tensor: Discrete representations (shape: [batch_size, time_steps]).
    """
    batch_size = features_.shape[0]
    quantized_ids = torch.zeros(batch_size, features_.shape[1], dtype=torch.long)
    for idx in range(batch_size):
        _features_ = features_[idx].view(-1, features_.shape[-1]).cpu().numpy()
        kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=n_init)
        kmeans.fit(_features_)
        quantized_ids[idx] = torch.tensor(kmeans.predict(_features_))
    return quantized_ids

In [None]:
# test ctc loss
# Load quantizer E0
vocab_size = 50
E0 = lambda x: Quantizer0(x, vocab_size)

# Load quantizer E1
E1 = MLPQuantizer(input_dim=768, vocab_size=vocab_size)
E1.train()

# Load perturbed features
perturbed_audio = augment_audio(torch.Tensor(dataset[0]["audio"]["array"]), sr=sampling_rate, augmentation_type="gaussian_noise", noise_level=0.5)
perturbed_features = feature_extractor(perturbed_audio, sampling_rate=sampling_rate, return_tensors="pt")['input_values']
perturbed_features = model(perturbed_features)['last_hidden_state']
# compute ctc loss
features = encoder_output
print("Features Shape:", features.shape)

ctc_loss = compute_ctc_loss(features, perturbed_features, E0, E1)
print("CTC Loss:", ctc_loss.item())

E1.eval()
E1.predict(perturbed_features)



Features Shape: torch.Size([1, 292, 768])
CTC Loss: 8.040390014648438


tensor([[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 11, 12, 11, 12, 12, 12,
         40, 11, 40,  0, 11, 12, 11, 11, 12, 12, 19, 19, 12, 17, 12, 12, 12, 17,
         17, 17, 17, 17, 17, 33, 33, 26, 26, 26, 18, 17, 17, 17, 17, 33, 33, 33,
         18, 18, 17, 17, 33, 33, 33, 24, 33, 33, 33, 33, 18, 18, 18, 18, 33, 33,
         33, 18, 18, 18, 34, 44, 44, 18, 44, 18, 18, 18, 18, 34, 34, 34, 34, 34,
         24, 44, 44, 34, 34, 34, 34, 34, 34, 34, 34, 34, 24, 18, 18, 34, 18, 18,
         18, 18, 34, 34, 34, 34, 34, 34, 34, 34, 34, 44, 44, 44, 44, 18, 44, 44,
         44, 29, 44, 44, 44, 44, 34, 44, 44, 44, 44, 18, 34, 34, 44, 44, 44, 44,
         44, 44, 44, 44, 44, 40, 40, 40, 44, 44, 44, 44, 44, 34, 44, 44, 44, 44,
         44, 44, 40, 44, 44, 44, 44, 34, 34, 34, 40, 18, 24, 18, 18, 18, 18, 18,
         18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 34, 34, 34, 29, 29, 44,
         44, 44, 44, 44, 44, 18, 18, 18, 18, 34, 34, 34, 44, 18, 18, 17, 17, 17,
         34, 34, 34, 34, 44,

### 2.1 Dataset preprocessing

In [None]:
from tqdm import tqdm

def preprocess_and_save_features(dataset, encoder, feature_extractor, augmentation_fn, sampling_rate, num_samples=100, save_path="precomputed_features.pt"):

    N = min(num_samples, len(dataset))
    encoded_features = []
    perturbed_encoded_features = []

    for idx in tqdm(range(N), desc="Processing dataset"):
        audio = dataset[idx]["audio"]["array"]

        # Extract features
        features = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")["input_values"]
        features = encoder(features)["last_hidden_state"].squeeze(0)  # Remove batch dim

        encoded_features.append(features.cpu().detach().numpy())  # Convert to NumPy

        # Apply augmentation
        perturbed_audio = augmentation_fn(torch.Tensor(audio), sampling_rate)
        perturbed_features = feature_extractor(perturbed_audio, sampling_rate=sampling_rate, return_tensors="pt")["input_values"]
        perturbed_features = encoder(perturbed_features)["last_hidden_state"].squeeze(0)  # Remove batch dim

        perturbed_encoded_features.append(perturbed_features.cpu().detach().numpy())  # Convert to NumPy

    # Save as a NumPy file or Torch tensor
    torch.save(encoded_features, save_path)
    torch.save(perturbed_encoded_features, save_path.replace(".pt", "_perturbed.pt"))
    print(f"Precomputed features saved at {save_path}")

In [None]:
def pad_batch(batch):
    """
    Pads both the dataset and perturbed dataset in the batch to the same sequence length
    and returns original lengths, padded dataset, and perturbed dataset.
    """
    # Extract original sequence lengths for both datasets
    dataset_tensors, perturbed_tensors = zip(*batch)  # Unzip dataset and perturbed dataset

    # Get the sequence lengths
    lengths = [tensor.shape[0] for tensor in dataset_tensors]

    # Pad both datasets
    padded_dataset = pad_sequence(dataset_tensors, batch_first=True, padding_value=0.0)
    padded_perturbed = pad_sequence(perturbed_tensors, batch_first=True, padding_value=0.0)

    return lengths, padded_dataset, padded_perturbed

In [None]:
## You can always use this routine before anything to get the preprocessed datasets.
# load dataset
num_samples = 10
augmentation_fn = lambda x, sampling_rate: augment_audio(x, sr=sampling_rate, augmentation_type="gaussian_noise", noise_level=0.01)
file_name = "precomputed_features.pt"
if not os.path.exists(file_name):
    preprocess_and_save_features(dataset, model, feature_extractor, augmentation_fn, sampling_rate, num_samples=num_samples, save_path=file_name)

# Load precomputed features
features = torch.load(file_name, weights_only=False)
perturbed_features = torch.load(file_name.replace(".pt", "_perturbed.pt"), weights_only=False)

features = [torch.tensor(f) if isinstance(f, np.ndarray) else f for f in features] # If numpy array then convert to torch tensor
perturbed_features = [torch.tensor(f) if isinstance(f, np.ndarray) else f for f in perturbed_features] # If numpy array then convert to torch tensor

Processing dataset: 100%|██████████| 10/10 [01:09<00:00,  6.98s/it]


Precomputed features saved at precomputed_features.pt


### 2.2 Training

In [None]:
def train_quantizer(E0, E1, dataset, perturbed_dataset, num_epochs=10, batch_size=16, learning_rate=1e-3):
    """
    Train the quantizer E1 using the CTC loss while keeping E0 frozen.

    Args:
        E0: Pretrained quantizer.
        E1: Trainable quantizer.
        dataset: Dataset containing features.
        perturbed_dataset: Dataset containing perturbed features.
        num_epochs: Number of training epochs.
        batch_size: Batch size for training.
        learning_rate: Learning rate for optimizer.

    Returns:
        Trained quantizer E1.
    """

    # Freeze E0 (if it has parameters)
    if hasattr(E0, 'parameters') and any(p.requires_grad for p in E0.parameters()):
        for param in E0.parameters():
            param.requires_grad = False

    # Set E1 to training mode
    E1.train()

    # Define optimizer for E1
    optimizer = optim.Adam(E1.parameters(), lr=learning_rate)

    # Create DataLoader
    dataloader = DataLoader(
        list(zip(dataset, perturbed_dataset)),
        batch_size=batch_size,
        shuffle=True,
        collate_fn=pad_batch
    )

    for epoch in range(num_epochs):
        total_loss = 0.0
        E1.train()
        for lengths, batch, perturbed_batch in dataloader:
            optimizer.zero_grad()

            # Load clean features
            clean_features_ = batch

            # Load perturbed features
            perturbed_features_ = perturbed_batch

            # Compute CTC loss
            loss = compute_ctc_loss(clean_features_, perturbed_features_, E0, E1, input_lengths=torch.LongTensor(lengths), target_threshold=0.35)
            # Backpropagation
            loss.backward() # Changed retain_graph to False to prevent RuntimeError
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch + 1}/{num_epochs} - CTC Loss: {avg_loss:.4f}")

    return E1  # Return trained model

In [None]:
print(len(features), len(perturbed_features))

10 10


In [None]:
# train quantizer
vocab_size = 50
# Load quantizer E0
E0 = lambda x: Quantizer0(x, vocab_size)

# Load quantizer E1
E1 = MLPQuantizer(input_dim=768, vocab_size=50)

E1 = train_quantizer(E0, E1, features, perturbed_features, num_epochs=15, batch_size=5, learning_rate=1e-3)
# save model E1
torch.save(E1.state_dict(), "E1.pth")

Epoch 1/15 - CTC Loss: 9.1688
Epoch 2/15 - CTC Loss: 8.8909
Epoch 3/15 - CTC Loss: 8.5497
Epoch 4/15 - CTC Loss: 8.2815
Epoch 5/15 - CTC Loss: 7.9384
Epoch 6/15 - CTC Loss: 7.6550
Epoch 7/15 - CTC Loss: 7.2955
Epoch 8/15 - CTC Loss: 7.0102
Epoch 9/15 - CTC Loss: 6.5961
Epoch 10/15 - CTC Loss: 6.3419
Epoch 11/15 - CTC Loss: 6.0299
Epoch 12/15 - CTC Loss: 5.7494
Epoch 13/15 - CTC Loss: 5.5464
Epoch 14/15 - CTC Loss: 5.3265
Epoch 15/15 - CTC Loss: 5.1351


### 2.3 UED computation

In [None]:
# load E1
E1 = MLPQuantizer(input_dim=768, vocab_size=50)
E1.load_state_dict(torch.load("E1.pth"))

<All keys matched successfully>

In [None]:
def compute_ued(dataset, perturbed_dataset, quantizer):
    """
    Compute the Unit Edit Distance (UED) metric.

    Args:
        dataset: Dataset containing clean features.
        perturbed_dataset: Dataset containing perturbed features.
        quantizer: Quantizer model.

    Returns:
        The average normalized Levenshtein distance across the dataset.
    """
    total_distance = 0
    total_frames = 0

    for (x, augmented_x) in zip(dataset, perturbed_dataset):
        x = x.unsqueeze(0)
        augmented_x = augmented_x.unsqueeze(0)
        quantized_x = quantizer.predict(x).flatten()  # Convert to token sequence
        quantized_aug_x = quantizer.predict(augmented_x).flatten()
        # Compute Levenshtein distance
        lev_dist = levenshtein_distance(quantized_x.tolist(), quantized_aug_x.tolist())
        total_distance += lev_dist  # Normalize by sequence length
        total_frames += 1  # Track the number of samples

    return total_distance / total_frames

In [None]:
class E0_quantizer():
    def __init__(self, kmeans_quantizer_fn):
        self.quantizer = kmeans_quantizer_fn

    def predict(self, x):
        return self.quantizer(x)

In [None]:
# compute ued E0:
E0_fn = lambda x: Quantizer0(x, vocab_size)
E0 = E0_quantizer(E0_fn)
ued = compute_ued(features, perturbed_features, E0)
print("UED:", ued)

UED: 524.4


In [None]:
# compute ued E1:
ued = compute_ued(features, perturbed_features, E1)
print("UED:", ued)

UED: 1.6


In [None]:
# Make test dataset
num_samples = 10
test_dataset = [dataset[num_samples -1 + idx] for idx in range(num_samples)]
augmentation_fn = lambda x, sampling_rate: augment_audio(x, sr=sampling_rate, augmentation_type="gaussian_noise", noise_level=0.01)
file_name = "test_precomputed_features.pt"
if not os.path.exists(file_name):
    preprocess_and_save_features(dataset, model, feature_extractor, augmentation_fn, sampling_rate, num_samples=num_samples, save_path=file_name)

# Load precomputed features
test_features = torch.load(file_name, weights_only=False)
test_perturbed_features = torch.load(file_name.replace(".pt", "_perturbed.pt"), weights_only=False)

test_features = [torch.tensor(f) if isinstance(f, np.ndarray) else f for f in test_features] # If numpy array then convert to torch tensor
test_perturbed_features = [torch.tensor(f) if isinstance(f, np.ndarray) else f for f in test_perturbed_features] # If numpy array then convert to torch tensor

Processing dataset: 100%|██████████| 10/10 [00:11<00:00,  1.18s/it]


Precomputed features saved at test_precomputed_features.pt


In [None]:
# compute ued E0:
E0_fn = lambda x: Quantizer0(x, vocab_size)
E0 = E0_quantizer(E0_fn)
ued = compute_ued(test_features, test_perturbed_features, E0)
print("UED:", ued)

# compute ued E1:
ued = compute_ued(test_features, test_perturbed_features, E1)
print("UED:", ued)

UED: 517.5
UED: 1.5


### 2.4 Some more basic functions

In [None]:
import random

def random_augment_audio(audio: torch.Tensor, sr: int, augmenter_audio: callable, augmentation_type_list: list, **kwargs) -> torch.Tensor:
    """
    Apply a random augmentation.

    augmenter_audio : your augmenter function (see mine in 2.0 for reference)
    augmentation_type : A list of augmentation functions to apply.
    **kwargs: Additional keyword arguments for specific augmentations

    /!\ remember that when you define the augmentation_fn for training use EXAMPLE :
    augmentation_type_list = ["gaussian_noise", ...]
    augmentation_fn = lambda x, sr: random_augment_audio(x, sr, augmentation_type_list, **kwargs)
    """
    # pick random augmentation
    augmentation_type = random.choice(augmentation_type_list)
    return augmenter_audio(audio, sr, augmentation_type, **kwargs)


In [None]:
def compute_ued_between_quantizer(dataset, perturbed_dataset, tested_quantizer, target_quantizer):
    """
    Compute the Unit Edit Distance (UED) metric.

    Args:
        dataset: Dataset containing clean features.
        perturbed_dataset: Dataset containing perturbed features.
        tested_quantizer: Quantizer model to be tested on perturbed data.
        target_quantizer: Quantizer model to use for the test (be sure it has been trained on same data if not E0).

    Returns:
        The average normalized Levenshtein distance across the dataset.
    """
    total_distance = 0
    total_frames = 0

    for (x, augmented_x) in zip(dataset, perturbed_dataset):
        x = x.unsqueeze(0)
        augmented_x = augmented_x.unsqueeze(0)
        quantized_x = target_quantizer.predict(x).flatten()  # Convert to token sequence
        quantized_aug_x = tested_quantizer.predict(augmented_x).flatten()
        # Compute Levenshtein distance
        lev_dist = levenshtein_distance(quantized_x.tolist(), quantized_aug_x.tolist())
        total_distance += lev_dist  # Normalize by sequence length
        total_frames += 1  # Track the number of samples

    return total_distance / total_frames

In [None]:
# compute ued E1 versus E0:
vocab_size = 50
E0_fn = lambda x: Quantizer0(x, vocab_size)
E0 = E0_quantizer(E0_fn)
E1 = MLPQuantizer(input_dim=768, vocab_size=50)
E1.load_state_dict(torch.load("E1.pth"))
ued = compute_ued_between_quantizer(features, perturbed_features, E1, E0)
print("UED:", ued)

UED: 535.8


### 2.5 ABX metric

In [32]:
class FullModel(nn.Module):
    def __init__(self, features_extractor, model, quantizer, sampling_rate):
        super().__init__()
        self.features_extractor = features_extractor
        self.model = model
        self.quantizer = quantizer
        self.sampling_rate = sampling_rate

    def forward(self, x):
        features = self.features_extractor(x, sampling_rate=self.sampling_rate, return_tensors="pt")['input_values']
        features = self.model(features)['last_hidden_state']
        token = self.quantizer.predict(features.unsqueeze(0)).flatten()
        return token

In [33]:
import random
import numpy as np

# ------------------------- load our model

E1 = MLPQuantizer(input_dim=768, vocab_size=50)
E1.load_state_dict(torch.load("E1.pth"))

full_model = FullModel(feature_extractor, model, E1, sampling_rate)

# ------------------------- I advise you to stream the dataset.
dataset = load_dataset('gilkeyio/librispeech-alignments', split='dev_clean', streaming=True)
for sample in dataset.take(1):
    if "audio" in sample:
        print("Sampling rate:", sample["audio"]["sampling_rate"])
        sampling_rate = sample["audio"]["sampling_rate"]
    else:
        print("No audio field in this sample.")
# ------------------------- Build a phoneme dictionary.
phoneme_dict = {}
num_examples = 100
counter = 0
for entry in dataset:
    phonemes = entry['phonemes']
    for ph in phonemes:
        phoneme = ph["phoneme"]
        if phoneme not in phoneme_dict:
            phoneme_dict[phoneme] = []
        end_time = ph["end"]
        start_time = ph["start"]
        # time are of the form 0.21, 0.22, etc.
        audio = entry["audio"]["array"][int(start_time * sampling_rate):int(end_time * sampling_rate)]
        phoneme_dict[phoneme].append(audio)
    counter += 1
    if counter >= num_examples:
        break

# ------------------------- Create ABX triplets.
# For each triplet:
#  - A and X come from the same phoneme class.
#  - B comes from a different phoneme class.
def create_abx_triplets(phoneme_dict, num_triplets=100):
    triplets = []
    phoneme_classes = list(phoneme_dict.keys())
    for _ in range(num_triplets):
        # Choose a phoneme class that has at least 2 of these for A and X.
        valid_classes = [ph for ph in phoneme_classes if len(phoneme_dict[ph]) >= 2]
        if not valid_classes:
            break
        class_A = random.choice(valid_classes)
        A_X = random.sample(phoneme_dict[class_A], 2)
        A_token, X_token = A_X[0], A_X[1]
        # Choose a different class for B (needs at least one of these)
        valid_B_classes = [ph for ph in phoneme_classes if ph != class_A and len(phoneme_dict[ph]) >= 1]
        if not valid_B_classes:
            break
        class_B = random.choice(valid_B_classes)
        B_token = random.choice(phoneme_dict[class_B])
        triplets.append((A_token, B_token, X_token))
    return triplets

triplets = create_abx_triplets(phoneme_dict, num_triplets=100)
print(f"Created {len(triplets)} ABX triplets.")

# ------------------------- Define a distance function. Here we use Levenhstein.
def distance_of_lev(output, target):
    return levenshtein_distance(output.tolist(), target.tolist())

# ------------------------- Compute ABX scores.
def compute_abx_score(triplets, model, distance = distance_of_lev):
    errors = 0
    total = len(triplets)
    for A_ph, B_ph, X_ph in triplets:
        A_token = model(A_ph)
        B_token = model(B_ph)
        X_token = model(X_ph)
        d_AX = distance(A_token, X_token)
        d_BX = distance(B_token, X_token)
        # For ABX, if X is closer to A than to B, it's correct.
        if d_AX > d_BX:
            errors += 1
    return errors / total

# -------------------------

abx_error_rate = compute_abx_score(triplets, full_model)
print("ABX error rate:", abx_error_rate)


Resolving data files:   0%|          | 0/49 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/66 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/49 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/66 [00:00<?, ?it/s]

Sampling rate: 16000
Created 100 ABX triplets.
ABX error rate: 0.3


## Augmentation

#### Useful functions for auditive tests

In [None]:
def play_audio(audio, sr=16000, title="audio"):
    audio_numpy = audio.squeeze().numpy()
    print(f"Listen : {title}")
    display(Audio(audio_numpy, rate=sr))

def visualize_and_play(audio, augmented_audio, sr=16000):
    play_audio(audio, sr, "original audio")
    play_audio(augmented_audio, sr, "augmented audio")

#### Audio effect of adding noise

In [None]:
# check gaussian noise
audio=torch.tensor([1,2,3,4,5,6,7,8,9], dtype=torch.float)

print(" Audio before transform : ")
print(audio)

noise=torch.randn_like(audio)

audio_with_noise=audio+noise

print(" Audio after adding noise : ")
print(audio_with_noise)

print("Audio after multiplying everything by 0.001 : ")
print(0.01*audio)

 Audio before transform : 
tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])
 Audio after adding noise : 
tensor([1.3143, 1.3834, 1.5881, 4.5715, 6.2245, 7.9845, 6.8187, 9.6168, 9.8978])
Audio after multiplying everything by 0.001 : 
tensor([0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700, 0.0800, 0.0900])


In [None]:
audio=torch.tensor(dataset[0]["audio"]["array"])
print(f"shape of the 'audio' signal: {audio.shape}")
noise = torch.randn_like(audio)
noisy = audio+0.01*noise # put >=0.01
visualize_and_play(audio, noisy, sr=16000)

shape of the 'audio' signal: torch.Size([93680])
Listen : original audio


Listen : augmented audio


#### Audio effect of changing speed

In [None]:
# check gaussian noise
audio=torch.tensor([1,2,3,4,5,6,7,8,9], dtype=torch.float)

print(" Audio before transform : ")
print(audio)

# change speed
def speed_change(audio, rate):
    indices = torch.arange(0, audio.shape[0], rate)
    indices = indices.long()
    return audio[indices]

faster_audio = speed_change(audio, 0.5)
print(" Audio after speed change : ")
print(faster_audio)

 Audio before transform : 
tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])
 Audio after speed change : 
tensor([1., 1., 2., 2., 3., 3., 4., 4., 5., 5., 6., 6., 7., 7., 8., 8., 9., 9.])


In [None]:
audio=torch.tensor(dataset[1]["audio"]["array"])
print(f"shape of the 'audio' signal: {audio.shape}")
audio_speed_change=speed_change(audio, 1.2)
visualize_and_play(audio, audio_speed_change, sr=16000)

shape of the 'audio' signal: torch.Size([77040])
Listen : original audio


Listen : augmented audio


#### Audio effect of changing "pitch"

In [None]:
audio=torch.tensor([1,2,3,4,5,6,7,8,9], dtype=torch.float)
print(" Audio before transform : ")
print(audio)

# change pitch high
def pitch_shift(audio, n_steps):
    factor = 2**(n_steps/12)
    orig_len = audio.shape[0]
    new_len = int(orig_len/factor)
    indices = torch.arange(0, orig_len, step=factor).long()
    indices = indices[:min(new_len, len(indices))]
    return audio[indices]

n_steps=2
pitch_shifted_audio = pitch_shift(audio, n_steps)
print("Audio with modified pitch:")
print(pitch_shifted_audio)

 Audio before transform : 
tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])
Audio with modified pitch:
tensor([1., 2., 3., 4., 5., 6., 7., 8.])


In [None]:
audio=torch.tensor(dataset[3]["audio"]["array"])
print(f"shape of the 'audio' signal: {audio.shape}")
audio_pitch_change=pitch_shift(audio, 10)
visualize_and_play(audio, audio_pitch_change, sr=16000)

shape of the 'audio' signal: torch.Size([158400])
Listen : original audio


Listen : augmented audio


#### Audio effect of clipping

In [None]:
audio=torch.tensor([1,2,3,4,5,6,7,8,9], dtype=torch.float)
print(" Audio before transform : ")
print(audio)

clip_factor = 0.8
max_val = audio.abs().max()
threshold = clip_factor*max_val
augmented_audio = torch.clamp(audio, -threshold, threshold)
print("audio after clipping :")
print(augmented_audio)

 Audio before transform : 
tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])
audio after clipping :
tensor([1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 7.2000, 7.2000])


In [None]:
audio=torch.tensor(dataset[4]["audio"]["array"])
print(f"shape of the 'audio' signal: {audio.shape}")
threshold = 0.01*audio.abs().max() # <0.1
clipped_audio = torch.clamp(audio,-threshold,threshold)
visualize_and_play(audio, clipped_audio, sr=16000)

Output hidden; open in https://colab.research.google.com to view.

#### Audio effect of filter (lowpass)

In [None]:
audio=torch.tensor(dataset[5]["audio"]["array"])
print(f"shape of the 'audio' signal: {audio.shape}")
audio_np = audio.numpy()
fft_signal = np.fft.rfft(audio_np)
freqs = np.fft.rfftfreq(len(audio_np), 1/16000)
mask = freqs<=1000
filtered_fft = fft_signal*mask
filtered_signal = np.fft.irfft(filtered_fft)
audio_lowpass=torch.tensor(filtered_signal, dtype=audio.dtype)
visualize_and_play(audio, audio_lowpass, sr=16000)

shape of the 'audio' signal: torch.Size([144160])
Listen : original audio


Listen : augmented audio


#### Audio effect of filter (bandpass)

In [None]:
audio=torch.tensor(dataset[6]["audio"]["array"])
print(f"shape of the 'audio' signal: {audio.shape}")
audio_np = audio.numpy()
fft_signal = np.fft.rfft(audio_np)
freqs = np.fft.rfftfreq(len(audio_np), 1/16000)
mask = (freqs >= 500) & (freqs <= 1500)
filtered_fft = fft_signal*mask
filtered_signal = np.fft.irfft(filtered_fft)
audio_bandpass=torch.tensor(filtered_signal, dtype=audio.dtype)
visualize_and_play(audio, audio_bandpass, sr=16000)

shape of the 'audio' signal: torch.Size([90240])
Listen : original audio


Listen : augmented audio


#### Audio effect of filter (highpass)

In [None]:
audio=torch.tensor(dataset[7]["audio"]["array"])
print(f"shape of the 'audio' signal: {audio.shape}")
audio_np = audio.numpy()
fft_signal = np.fft.rfft(audio_np)
freqs = np.fft.rfftfreq(len(audio_np), 1/16000)
mask = freqs>=500
filtered_fft = fft_signal*mask
filtered_signal = np.fft.irfft(filtered_fft)
audio_highpass=torch.tensor(filtered_signal, dtype=audio.dtype)
visualize_and_play(audio, audio_highpass, sr=16000)

shape of the 'audio' signal: torch.Size([147840])
Listen : original audio


Listen : augmented audio


#### Audio effects of adding many little bips

In [None]:
def bip(audio):
  audio_bipped=audio.clone()
  max=audio.abs().max()
  for k in range(audio.shape[0]):
    if k%2==0:
      audio_bipped[k]=max
  return audio_bipped

In [None]:
audio=torch.tensor(dataset[8]["audio"]["array"])
print(f"shape of the 'audio' signal: {audio.shape}")
audio_bip= bip(audio)
visualize_and_play(audio, audio_bip, sr=16000)

shape of the 'audio' signal: torch.Size([81920])
Listen : original audio


Listen : augmented audio


#### Audio effects of adding some big bips

In [None]:
def big_bip(audio):
  audio_bipped=audio.clone()
  max=audio.abs().max()
  for k in range(audio.shape[0]):
    if k%10000==0 and k+100<=audio.shape[0]:
      audio_bipped[k:k+100]=torch.tensor([max]*100)
  return audio_bipped

In [None]:
audio=torch.tensor(dataset[9]["audio"]["array"])
print(f"shape of the 'audio' signal: {audio.shape}")
audio_big_bip= big_bip(audio)
visualize_and_play(audio, audio_big_bip, sr=16000)

shape of the 'audio' signal: torch.Size([292640])
Listen : original audio


Listen : augmented audio


#### Audio effect of adding an echo

In [None]:
import matplotlib.pyplot as plt
from IPython.display import Audio, display

def echo(audio, nb_echos=5):
    nb_samples=audio.shape[0]
    '''
    plt.plot(audio)
    plt.title("audio")
    plt.show()
    '''
    echo_audio=audio.clone()
    for i in range(nb_echos):
      '''
      plt.plot((1/(i+1)**2)*audio[:(nb_samples-(1000*i))])
      plt.title(f"echo {i}")
      plt.show()
      '''
      echo_audio[(4000*i):]+=(1/(i+1)**2)*audio[:(nb_samples-(4000*i))]
    return echo_audio

audio=torch.tensor(dataset[10]["audio"]["array"])
print(f"shape of the 'audio' signal: {audio.shape}")
echo_audio=echo(audio, nb_echos=10)
visualize_and_play(audio, echo_audio, sr=16000)

shape of the 'audio' signal: torch.Size([89600])
Listen : original audio


Listen : augmented audio


#### Function augment_audio

In [None]:
def augment_audio(audio: torch.Tensor, sr: int, augmentation_type: str, **kwargs) -> torch.Tensor:
    """
    Apply an augmentation.

    Parameters:
      audio (torch.Tensor): The input audio waveform (shape: [channels, samples]).
      sr (int): Sampling rate of the audio.
      augmentation_type (str): The type of augmentation to perform.
                                 Options: "gaussian_noise", "time_stretch",
                                          "pitch_shift", "clipping", "lowpass", "bandpass", "highpass", "little_bips",
                                          "big_bips", "echo"
      **kwargs: Additional keyword arguments for specific augmentations.

    Returns:
      torch.Tensor: The augmented audio waveform.
    """
    # Audio effect of adding noise
    if augmentation_type == "gaussian_noise":
        noise_level = kwargs.get("noise_level", 0.01)
        noise = torch.randn_like(audio)
        noisy = audio+0.01*noise # put >=0.01
        return augmented_audio

    # Audio effect of changing speed
    elif augmentation_type == "time_stretch":
        rate = kwargs.get("rate", 1.2)  # >1 speeds up, <1 slows down.
        return speed_change(audio, rate)

    # Audio effect of changing "pitch"
    elif augmentation_type == "pitch_shift":
        n_steps = kwargs.get("n_steps", 2)
        return pitch_shift(audio, n_steps)

    # Audio effect of clipping
    elif augmentation_type == "clipping":
        clip_factor = kwargs.get("clip_factor", 0.8)
        threshold = 0.01*audio.abs().max() # <0.1
        return torch.clamp(audio,-threshold,threshold)

    # Audio effect of filter (lowpass)
    elif augmentation_type == "lowpass":
      freqlim = kwargs.get("freqlim", 1000)
      audio_np = audio.numpy()
      fft_signal = np.fft.rfft(audio_np)
      freqs = np.fft.rfftfreq(len(audio_np), 1/16000)
      mask = freqs<=freqlim
      filtered_fft = fft_signal*mask
      filtered_signal = np.fft.irfft(filtered_fft)
      audio_lowpass=torch.tensor(filtered_signal, dtype=audio.dtype)
      return audio_lowpass

    # Audio effect of filter (bandpass)
    elif augmentation_type == "bandpass":
      freqlim1 = kwargs.get("freqlim1", 500)
      freqlim2 = kwargs.get("freqlim2", 1500)
      audio_np = audio.numpy()
      fft_signal = np.fft.rfft(audio_np)
      freqs = np.fft.rfftfreq(len(audio_np), 1/16000)
      mask = (freqs >= freqlim1) & (freqs <= freqlim2)
      filtered_fft = fft_signal*mask
      filtered_signal = np.fft.irfft(filtered_fft)
      audio_bandpass=torch.tensor(filtered_signal, dtype=audio.dtype)
      return audio_bandpass

    # Audio effect of filter (highpass)
    elif augmentation_type == "highpass":
      freqlim = kwargs.get("freqlim", 500)
      audio_np = audio.numpy()
      fft_signal = np.fft.rfft(audio_np)
      freqs = np.fft.rfftfreq(len(audio_np), 1/16000)
      mask = freqs>=freqlim
      filtered_fft = fft_signal*mask
      filtered_signal = np.fft.irfft(filtered_fft)
      audio_highpass=torch.tensor(filtered_signal, dtype=audio.dtype)
      return audio_highpass

    # Audio effects of adding many little bips
    elif augmentation_type == "little_bips":
      audio_bip=bip(audio)
      return audio_bip


    # Audio effects of adding some big bips
    elif augmentation_type == "big_bips":
      audio_big_bip= big_bip(audio)
      return audio_big_bip


    # Audio effect of adding an echo
    elif augmentation_type == "echo":
      nb_echos = kwargs.get("nb_echos", 5)
      echo_audio=echo(audio, nb_echos)
      return echo_audio

    else:
        raise ValueError(f"Unknown augmentation type: {augmentation_type}")