# 1. Preliminaries

In [None]:
%%capture
# Colab setup: install the `datasets` package and clone the MS-SNSD noise corpus
# (its noise_train/ and noise_test/ folders are read from /content/MS-SNSD below).
# NOTE(review): `datasets` is unpinned — pin a version for reproducibility.
!pip install datasets
!git clone https://github.com/microsoft/MS-SNSD.git

In [None]:
import torch
import torchaudio
from transformers import AutoProcessor, HubertModel, Wav2Vec2FeatureExtractor
import os
import numpy as np
import librosa
from tqdm import tqdm
from IPython.display import Audio, display
import pickle
from typing import Optional, Tuple, Union
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from sklearn.cluster import MiniBatchKMeans, KMeans

# Root directory for the downloaded corpora; created up front so later cells can write into it.
data_dir = './data'
if not os.path.isdir(data_dir):
    os.makedirs(data_dir)

In [None]:
# LibriSpeech splits used below: dev-clean (frames for fitting the speech k-means),
# train-clean-100 (pre-training) and test-clean (held out). All are stored under
# the configured `data_dir` rather than a repeated literal path.
librispeech_dev = torchaudio.datasets.LIBRISPEECH(data_dir, url="dev-clean", download=True)
librispeech_train = torchaudio.datasets.LIBRISPEECH(data_dir, url="train-clean-100", download=True)
librispeech_test = torchaudio.datasets.LIBRISPEECH(data_dir, url="test-clean", download=True)

  3%|▎         | 8.49M/322M [00:02<00:46, 7.09MB/s]

In [None]:
# Download three files from Google Drive by file ID.
# NOTE(review): the IDs do not say what these files are (presumably pre-computed
# k-means models / label pickles used further down) — document their names and contents.
!gdown 1wVq5bdLcEjR-qKm2UOl1jCwD7hEAQsDj
!gdown 1QHsrC52r8lhsHBPIy9Din0ir6CWM9nAf
!gdown 1FzwfPHgvZoWSR4iwHdUX_SsuujCSib8o

In [None]:
name = "facebook/hubert-base-ls960"
try:
  processor = AutoProcessor.from_pretrained(name)
  model = HubertModel.from_pretrained(name)
except:
  processor = Wav2Vec2FeatureExtractor.from_pretrained(name)
  model = HubertModel.from_pretrained(name)


In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

# 2. Quantization

## 2.1 Quantize noise and speech separately

In [None]:
class ds_noise_mssnsd(torch.utils.data.Dataset):
  """Dataset over every ``.wav`` file found (recursively) under ``noise_dir``.

  Each item is the clip loaded at 16 kHz, multiplied by ``factor``, cut to its
  first 200000 samples and passed through ``processor``; every tensor in the
  processor output then has its size-1 dimensions squeezed away.
  """

  def __init__(self, noise_dir, processor, factor = 1):
    self.processor = processor
    self.noise_dir = noise_dir
    self.factor = factor
    # Keep os.walk order: labels computed from this dataset are stored per file in this order.
    self.file_paths = [
        os.path.join(root, fname)
        for root, _subdirs, fnames in os.walk(noise_dir)
        for fname in fnames
        if fname.endswith(".wav")
    ]

  def __len__(self):
    return len(self.file_paths)

  def __getitem__(self, idx):
    audio, sr = librosa.load(self.file_paths[idx], sr=16000)
    scaled = audio * self.factor
    features = self.processor(scaled[:200000], return_tensors="pt", sampling_rate = sr)
    for key in features:
      features[key] = features[key].squeeze()
    return features

class ds_signal_mssnsd(torch.utils.data.Dataset):
  """Wraps an indexable speech dataset (e.g. torchaudio LIBRISPEECH) for feature extraction.

  ``data[idx][0]`` is taken as the waveform (assumed to already be 16 kHz —
  TODO confirm for non-LibriSpeech inputs). It is squeezed, cut to its first
  200000 samples and passed through ``processor``; size-1 dimensions of every
  output tensor are squeezed away.

  Args:
    data: indexable dataset whose items hold the waveform at position 0.
    processor: callable feature extractor (e.g. a HF processor).
    leng: optional cap on the number of items exposed; ``None`` exposes all of ``data``.
  """
  def __init__(self, data, processor, leng = None):
    self.data = data
    self.processor = processor
    self.leng = leng

  def __len__(self):
    # Never report more items than `data` holds, otherwise a DataLoader indexes past the end.
    if self.leng is not None:
      return min(self.leng, len(self.data))
    return len(self.data)

  def __getitem__(self, idx):
    clean_speech = self.data[idx][0].squeeze()
    # Use the processor given to this dataset, not the notebook-global `processor`.
    encoded = self.processor(clean_speech[:200000], return_tensors="pt", sampling_rate = 16000)
    for k in encoded:
      encoded[k] = encoded[k].squeeze()
    return encoded

In [None]:
# Collect frame-level HuBERT features (hidden_states[6]) for every MS-SNSD
# training-noise clip and for the first 500 LibriSpeech dev-clean utterances.
# Each list ends up with one Python list of floats per frame, which is the
# input the k-means cells below turn into arrays.
emb_noise = []
emb_signal = []
data_noise = ds_noise_mssnsd("/content/MS-SNSD/noise_train/", processor, factor = 1)
data_signal = ds_signal_mssnsd(librispeech_dev, processor, leng = 500)
loader_noise = torch.utils.data.DataLoader(data_noise, batch_size=1, shuffle=False)
loader_signal = torch.utils.data.DataLoader(data_signal, batch_size=1, shuffle=False)
tbar_noise = tqdm(loader_noise, position=0, leave=True)
tbar_signal = tqdm(loader_signal, position=0, leave=True)

model.eval()  # feature extraction must be deterministic (no dropout), whatever earlier cells did
# Inference only: no_grad avoids building an autograd graph for every clip.
with torch.no_grad():
  for tbar, sink in ((tbar_noise, emb_noise), (tbar_signal, emb_signal)):
    for batch in tbar:
      encoded = model(batch.input_values.to(device), output_hidden_states = True)
      embs = encoded["hidden_states"][6].cpu().numpy().tolist()
      for e in embs: sink.extend(e)

100%|██████████| 128/128 [00:08<00:00, 14.94it/s]
100%|██████████| 500/500 [00:32<00:00, 15.18it/s]


In [None]:
# Fit a 500-cluster codebook on the noise frames and another on the clean-speech
# frames, and keep both as kmeans_noise_500 / kmeans_signal_500.
# Full-batch KMeans on these arrays was interrupted (KeyboardInterrupt) before it
# finished, which can leave kmeans_signal_500 undefined. MiniBatchKMeans is used
# instead, as in the other k-means cells of this notebook.
kmeans_noise_500 = MiniBatchKMeans(n_clusters=500, random_state=0, n_init = 10).fit(np.array(emb_noise))
kmeans_signal_500 = MiniBatchKMeans(n_clusters=500, random_state=0, n_init = 10).fit(np.array(emb_signal))





KeyboardInterrupt: 

Exception ignored in: 'sklearn.cluster._k_means_common._relocate_empty_clusters_dense'
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/numpy/core/multiarray.py", line 346, in where
    @array_function_from_c_func_and_dispatcher(_multiarray_umath.where)
KeyboardInterrupt: 


KeyboardInterrupt: 

In [None]:
# Refit the noise codebook with MiniBatchKMeans. n_clusters=500 matches the variable
# name, the 'kmeans_noise_500.pkl' file it is saved to, and the 500-way noise head
# (codebook_size=500) in hubert_pretrain.
kmeans_noise_500 = MiniBatchKMeans(n_clusters=500, random_state=0, n_init = 10).fit(np.array(emb_noise))



In [None]:
# Persist both codebooks so later sessions can reload them instead of refitting.
with open('kmeans_noise_500.pkl', 'wb') as noise_fh:
    pickle.dump(kmeans_noise_500, noise_fh)

with open('kmeans_signal_500.pkl', 'wb') as signal_fh:
    pickle.dump(kmeans_signal_500, signal_fh)


In [None]:
# Reload the saved codebooks. pickle.load can execute code from the file: only load trusted pickles.
with open('kmeans_noise_500.pkl', 'rb') as noise_fh:
    kmeans_noise_500 = pickle.load(noise_fh)

with open('kmeans_signal_500.pkl', 'rb') as signal_fh:
    kmeans_signal_500 = pickle.load(signal_fh)

In [None]:
# Quantization 1: turn every clip into a sequence of k-means cluster ids computed
# from HuBERT hidden_states[6]. Noise clips use the noise codebook; LibriSpeech
# train/test utterances use the speech codebook.
data_train = ds_signal_mssnsd(librispeech_train, processor)
data_test = ds_signal_mssnsd(librispeech_test, processor)
data_noise = ds_noise_mssnsd("/content/MS-SNSD/noise_train/", processor, factor = 1)
loader_train = torch.utils.data.DataLoader(data_train, batch_size=1, shuffle=False)
loader_test = torch.utils.data.DataLoader(data_test, batch_size=1, shuffle=False)
loader_noise = torch.utils.data.DataLoader(data_noise, batch_size=1, shuffle=False)
tbar_train = tqdm(loader_train, position=0, leave=True)
tbar_test = tqdm(loader_test, position=0, leave=True)
tbar_noise = tqdm(loader_noise, position=0, leave=True)
tag_train = []
tag_test = []
tag_noise = []

model.eval()  # deterministic features (no dropout), whatever earlier cells did
jobs = (
    (tbar_noise, kmeans_noise_500, tag_noise),
    (tbar_train, kmeans_signal_500, tag_train),
    (tbar_test, kmeans_signal_500, tag_test),
)
# Inference only: no_grad avoids keeping an autograd graph per utterance.
with torch.no_grad():
  for tbar, codebook, tags in jobs:
    for batch in tbar:
      encoded = model(batch.input_values.to(device), output_hidden_states = True)
      # batch_size is 1: index the batch axis away (squeeze() would also drop a 1-frame axis).
      embs = encoded["hidden_states"][6][0].cpu().numpy()
      # Predict on float64, the dtype of the arrays the codebooks were fitted on.
      tags.append(codebook.predict(embs.astype(float)))



100%|██████████| 128/128 [00:04<00:00, 26.32it/s]
100%|██████████| 28539/28539 [22:05<00:00, 21.53it/s]
100%|██████████| 2620/2620 [23:56<00:00,  1.82it/s]


In [None]:
# Bundle the per-clip cluster-id arrays of every split and pickle them together.
kmeans_labels = dict(train=tag_train, test=tag_test, noise=tag_noise)
with open('kmeans_labels_1_2.pkl', 'wb') as labels_fh:
    pickle.dump(kmeans_labels, labels_fh)

In [None]:
# Save the noise codebook under a file name taken from its actual cluster count.
# The variable name alone does not guarantee the size, so read n_clusters from
# the fitted model and keep the file name consistent with the contents.
with open(f'kmeans_noise_{kmeans_noise_500.n_clusters}.pkl', 'wb') as file:
    pickle.dump(kmeans_noise_500, file)

In [None]:
# Download the labels pickle, which was written to the relative path
# 'kmeans_labels_1_2.pkl'. Resolving it against the working directory matches
# that path (it is /content in Colab) instead of hard-coding /content.
from google.colab import files
files.download(os.path.abspath('kmeans_labels_1_2.pkl'))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## 2.2 Quantize noise, signal, and noise + signal together

In [None]:
len(emb_noise)  # number of noise frames collected for k-means

79844


In [None]:
len(emb_signal)  # number of clean-speech frames collected for k-means

164562


In [None]:
# Quantization 2: a single codebook fitted on noise and clean-speech frames together.
all_emb = emb_noise + emb_signal  # new list; emb_noise / emb_signal are left untouched

# n_clusters must agree with the `_500` name and file, and with the 500-way heads
# of hubert_pretrain: labels >= 500 would be out of range for the pre-training loss.
kmeans_all_500 = MiniBatchKMeans(n_clusters=500, random_state=1, n_init = 10).fit(np.array(all_emb))

with open('kmeans_all_500.pkl', 'wb') as file:
    pickle.dump(kmeans_all_500, file)

# Re-label every clip (noise, train, test) with the joint codebook.
# NOTE: this overwrites the tag_train / tag_test / tag_noise globals that train() reads.
data_train = ds_signal_mssnsd(librispeech_train, processor)
data_test = ds_signal_mssnsd(librispeech_test, processor)
data_noise = ds_noise_mssnsd("/content/MS-SNSD/noise_train/", processor, factor = 1)
loader_train = torch.utils.data.DataLoader(data_train, batch_size=1, shuffle=False)
loader_test = torch.utils.data.DataLoader(data_test, batch_size=1, shuffle=False)
loader_noise = torch.utils.data.DataLoader(data_noise, batch_size=1, shuffle=False)
tbar_train = tqdm(loader_train, position=0, leave=True)
tbar_test = tqdm(loader_test, position=0, leave=True)
tbar_noise = tqdm(loader_noise, position=0, leave=True)
tag_train = []
tag_test = []
tag_noise = []

model.eval()  # deterministic features (no dropout), whatever earlier cells did
# Inference only: no_grad avoids keeping an autograd graph per utterance.
with torch.no_grad():
  for tbar, tags in ((tbar_noise, tag_noise), (tbar_train, tag_train), (tbar_test, tag_test)):
    for batch in tbar:
      encoded = model(batch.input_values.to(device), output_hidden_states = True)
      # batch_size is 1: index the batch axis away (squeeze() would also drop a 1-frame axis).
      embs = encoded["hidden_states"][6][0].cpu().numpy()
      tags.append(kmeans_all_500.predict(embs.astype(float)))

kmeans_labels = {
    "train": tag_train,
    "test": tag_test,
    "noise": tag_noise
}
with open('kmeans_labels_2.pkl', 'wb') as file:
    pickle.dump(kmeans_labels, file)

  0%|          | 0/2620 [00:39<?, ?it/s]
100%|██████████| 128/128 [00:05<00:00, 23.49it/s]
100%|██████████| 28539/28539 [22:56<00:00, 20.73it/s]
100%|██████████| 2620/2620 [24:52<00:00,  1.76it/s]


# 3. Training

In [None]:
def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    NOTE(review): this appears to be adapted from the Hugging Face transformers
    Wav2Vec2/HuBERT implementation; re-sync it when upgrading transformers.
    Randomness comes from the global NumPy state (`np.random.rand`, `np.random.choice`),
    so seed NumPy for reproducible masks.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.

    Returns:
        A boolean array of shape `(batch_size, sequence_length)` in which True marks a masked
        position. It is all False when not even one span fits (`max_num_masked_span == 0`).

    Raises:
        ValueError: if `mask_length < 1` or `mask_length > sequence_length`.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding
    # (one draw, shared by every row of the batch)
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    # (per-row unpadded lengths when a mask is given, else the full length for every row)
    input_lengths = (
        attention_mask.sum(-1).detach().tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    # Every row is padded to this many span starts so the index matrix is rectangular.
    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        # (distinct span starts, each leaving room for a full span inside input_length)
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index that will serve as a dummy index to pad vector
        # to ensure same dimension for all batches due to probabilistic rounding
        # Picking first sample just pads those vectors twice.
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller then
            # `sequence_length` in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    # (batch_size, max_num_masked_span) matrix of span start indices
    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask

class masked_hubert(type(model)):
    """HuBERT encoder that also returns the time mask it applied.

    Subclasses the class of the notebook's loaded `model` and overrides
    `_mask_hidden_states` and `forward`, so that the SpecAugment-style time mask
    applied to the projected CNN features is handed back to the caller. The
    pre-training loss uses it to separate masked from unmasked frames.
    """

    def _mask_hidden_states(
          self,
          hidden_states: torch.FloatTensor,
          mask_time_indices: Optional[torch.FloatTensor] = None,
          attention_mask: Optional[torch.LongTensor] = None,
      ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).

        Args:
            hidden_states: projected features, unpacked as (batch, sequence, hidden); modified in place.
            mask_time_indices: optional precomputed boolean time mask, applied as given.
            attention_mask: optional frame-level mask; when given, randomly drawn time spans
                stay inside each row's unpadded length.

        Returns:
            Always the 2-tuple `(hidden_states, mask_time_indices)`. The mask is a bool tensor
            of shape (batch, sequence), or None when no time mask was applied.
        """

        # `config.apply_spec_augment` can disable masking. Still return the 2-tuple
        # (forward unpacks two values); nothing was masked, so the mask is None.
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states, None

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            # Random time spans are only drawn in training mode.
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states, mask_time_indices

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) :
        """Encode raw audio, applying and reporting the time mask.

        Args:
            input_values: raw waveform batch as produced by the processor.
            attention_mask: optional sample-level mask. It is reduced to frame level and is
                used both by the encoder and to keep random time spans inside real frames.
            mask_time_indices: optional precomputed time mask, used instead of a random one.
            output_attentions, output_hidden_states, return_dict: default to the model config.

        Returns:
            If `return_dict` is false, the tuple `(last_hidden_state, *encoder_outputs[1:])`;
            the mask is not included. Otherwise a plain dict with keys "last_hidden_state",
            "hidden_states", "attentions" and "mask_time_index". The mask is a bool
            (batch, frames) tensor, or None when nothing was masked.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states = self.feature_projection(extract_features)
        # Pass the frame-level mask so random spans fall only on real (unpadded) frames;
        # padded frames carry only ignored (-100) targets in pre-training.
        hidden_states, mask_time_index = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return {
            "last_hidden_state" : hidden_states,
            "hidden_states" : encoder_outputs.hidden_states,
            "attentions" : encoder_outputs.attentions,
            "mask_time_index" : mask_time_index,
        }

In [None]:
class hubert_pretrain(torch.nn.Module):
  """Masked-prediction pre-training wrapper: an encoder plus two linear heads.

  Both heads read the encoder's last hidden state and emit `codebook_size`
  logits per position. One head scores noise cluster ids, the other scores
  speech cluster ids.

  Args:
    model: encoder called as `model(input_values, attention_mask)`. It must return a dict
      with "last_hidden_state" and "mask_time_index" (as `masked_hubert` does) and
      expose `config.hidden_size`.
    codebook_size: number of k-means clusters, i.e. output classes of each head.
  """
  def __init__(self, model, codebook_size = 500):
    super(hubert_pretrain, self).__init__()
    self.model = model
    self.emb_size = self.model.config.hidden_size
    self.linear_noise = torch.nn.Linear(self.emb_size, codebook_size)
    self.linear_speech = torch.nn.Linear(self.emb_size, codebook_size)

  def forward(self, data):
    """Run one batch through the encoder and both heads.

    Args:
      data: batch object exposing `input_values` and, optionally, `attention_mask`
        as attributes (e.g. a HF BatchFeature produced by the processor).

    Returns:
      `(noise_logits, speech_logits, mask_idx)`. The logits have shape (..., codebook_size);
      `mask_idx` is whatever the encoder reports under "mask_time_index".
    """
    # Move inputs to wherever this module's weights live instead of a notebook-global `device`.
    target = next(self.parameters()).device
    attention_mask = getattr(data, "attention_mask", None)
    if attention_mask is not None:
      attention_mask = attention_mask.to(target)
    encoded = self.model(data.input_values.to(target), attention_mask)
    last_hidden = encoded["last_hidden_state"]
    noise_logits = self.linear_noise(last_hidden)
    speech_logits = self.linear_speech(last_hidden)
    return noise_logits, speech_logits, encoded["mask_time_index"]

In [None]:
# Fresh masked_hubert on `device`, with the pre-trained weights copied in from `model`.
hubert_masked = masked_hubert(model.config).to(device)
hubert_masked.load_state_dict(model.state_dict());

In [None]:
# Sanity check: encode the first test-clean waveform with the same padding and truncation
# settings (fixed 200000 samples, attention mask) that pre_train_dataset uses.
a = processor(
    librispeech_test[0][0].squeeze(),
    sampling_rate=16000,
    return_tensors="pt",
    padding="max_length",
    max_length=200000,
    truncation=True,
    return_attention_mask=True,
)

In [None]:
class pre_train_dataset(torch.utils.data.Dataset):
  """Noisy-speech pre-training examples with per-frame speech and noise targets.

  Each item mixes utterance `idx` of `data` with randomly drawn noise clips from
  `noise_dir`, concatenated until they cover the speech. The mixture is encoded
  with `processor`, padded/truncated to `MAX_SAMPLES` samples. The item also
  carries the cluster ids of the clean speech and of the noise, each padded with
  -100 (the ignore value used by the training loss) to `LABEL_FRAMES` entries.

  Args:
    data: indexable speech dataset; `data[idx][0]` is the waveform (assumed 16 kHz).
    processor: feature extractor applied to the noisy waveform.
    noise_dir: directory searched recursively for `.wav` noise files.
    data_label: per-utterance cluster-id sequences, aligned with `data`.
    noise_label: per-file cluster-id sequences, aligned with the `.wav` files of
      `noise_dir` in `os.walk` order.

  Raises:
    ValueError: if `noise_dir` contains no `.wav` files.
  """

  SAMPLE_RATE = 16000
  MAX_SAMPLES = 200000
  # Label length of every item (presumably the HuBERT frame count for MAX_SAMPLES samples).
  LABEL_FRAMES = 624

  def __init__(self, data, processor, noise_dir, data_label, noise_label):
    import warnings

    self.data = data
    self.processor = processor
    self.noise_dir = noise_dir
    self.data_label = data_label
    self.noise_label = noise_label
    self.file_paths = [
        os.path.join(root, fname)
        for root, _subdirs, fnames in os.walk(noise_dir)
        for fname in fnames
        if fname.endswith(".wav")
    ]
    if not self.file_paths:
      raise ValueError(f"no .wav noise files found under {noise_dir!r}")
    if len(self.noise_label) != len(self.file_paths):
      # Noise labels are looked up by file position, so a count mismatch means they
      # were computed for a different set of files.
      warnings.warn(
          f"{len(self.noise_label)} noise label sequences for {len(self.file_paths)} noise files "
          f"in {noise_dir!r}; noise targets will be misaligned (or raise IndexError)"
      )

  def __len__(self):
    return len(self.data)

  def pad_label(self, label, leng):
    """Return an int64 tensor of length `leng`: `label` truncated, or right-padded with -100."""
    result = torch.full(size = (leng, ), fill_value = -100)
    n = min(leng, len(label))
    if n:
      result[:n] = torch.as_tensor(np.asarray(label[:n]))
    return result

  def _sample_noise(self, min_len):
    """Concatenate random noise clips until at least `min_len` samples (and one clip) are collected.

    Returns the concatenated waveform and the concatenated per-clip cluster ids.
    """
    clips, clip_labels, total = [], [], 0
    while total < min_len or not clips:
      noise_idx = np.random.choice(len(self.file_paths))
      clip, _ = librosa.load(self.file_paths[noise_idx], sr=self.SAMPLE_RATE)
      clips.append(clip)
      clip_labels.append(np.asarray(self.noise_label[noise_idx]))
      total += len(clip)
    return np.concatenate(clips), np.concatenate(clip_labels)

  def __getitem__(self, idx):
    # Only the first MAX_SAMPLES samples are ever encoded, so only that much noise is needed.
    clean_speech = self.data[idx][0].squeeze()[:self.MAX_SAMPLES]
    noise, noise_label = self._sample_noise(len(clean_speech))
    noisy_speech = clean_speech + noise[:len(clean_speech)]
    # Use the processor given to this dataset, not the notebook-global `processor`.
    encoded = self.processor(
        noisy_speech, return_tensors="pt",
        sampling_rate = self.SAMPLE_RATE, padding = "max_length",
        max_length = self.MAX_SAMPLES, truncation=True,
        return_attention_mask = True
    )
    encoded["speech_label"] = self.pad_label(self.data_label[idx], self.LABEL_FRAMES)
    encoded["noise_label"] = self.pad_label(noise_label, self.LABEL_FRAMES)
    for k in encoded: encoded[k] = encoded[k].squeeze()
    return encoded

In [None]:
def train(num_epochs = 10, lr = 1e-5):
  """Pre-train a masked HuBERT to predict speech and noise cluster ids from noisy audio.

  Builds a fresh `masked_hubert` initialised from the notebook's `model` and wraps it
  in `hubert_pretrain`. It then optimises the sum of four cross-entropy terms
  (speech/noise targets x masked/unmasked frames) on LibriSpeech train-clean-100
  mixed with MS-SNSD training noise. Targets come from the notebook globals
  `tag_train` and `tag_noise`.

  Args:
    num_epochs: number of passes over the training set.
    lr: Adam learning rate.

  Returns:
    The trained `hubert_pretrain` module (encoder plus both heads).
  """
  hubert_masked = masked_hubert(model.config)
  hubert_masked.load_state_dict(model.state_dict())
  hubert_pt = hubert_pretrain(hubert_masked)
  hubert_pt = hubert_pt.to(device)
  # Both heads share this size; use it instead of a hard-coded 500.
  codebook_size = hubert_pt.linear_speech.out_features
  train_dataset = pre_train_dataset(librispeech_train, processor,
                                    noise_dir = "/content/MS-SNSD/noise_train/",
                                    data_label = tag_train,
                                    noise_label = tag_noise)
  # NOTE(review): tag_noise holds per-file labels for noise_train/, indexed by file
  # position, so they do not describe the noise_test/ files used here. Compute labels
  # for noise_test/ before evaluating on this split (it is not used below).
  test_dataset = pre_train_dataset(librispeech_test, processor,
                                    noise_dir = "/content/MS-SNSD/noise_test/",
                                    data_label = tag_test,
                                    noise_label = tag_noise)
  train_dataloader = DataLoader(train_dataset, batch_size = 1, shuffle = True)
  test_dataloader = DataLoader(test_dataset, batch_size = 1, shuffle = True)
  # Optimise the whole wrapper. hubert_masked.parameters() alone would leave the
  # linear_noise / linear_speech heads at their random initialisation.
  optimizer = torch.optim.Adam(hubert_pt.parameters(), lr = lr)

  def masked_ce(logits, labels):
    """Mean cross-entropy over positions whose label is not -100.

    When every position is ignored, CrossEntropyLoss returns NaN, which would poison
    the weights on backward. In that case a zero still attached to the graph is returned.
    """
    flat_logits = logits.reshape(-1, codebook_size)
    flat_labels = labels.reshape(-1)
    keep = flat_labels != -100
    if not bool(keep.any()):
      return flat_logits.sum() * 0.0
    return F.cross_entropy(flat_logits[keep], flat_labels[keep])

  for epoch in range(num_epochs):
    # Put the module being optimised (not the global `model`) into training mode;
    # masked_hubert only draws random time masks while training.
    hubert_pt.train()
    tbar_train = tqdm(train_dataloader, position=0, leave=True, desc=f"epoch {epoch + 1}/{num_epochs}")
    for batch in tbar_train:
      optimizer.zero_grad()
      noise_logits, speech_logits, mask_idx = hubert_pt(batch)
      speech_label = batch["speech_label"].to(device)
      noise_label = batch["noise_label"].to(device)
      if mask_idx is None:
        # No time mask was applied: every frame counts as unmasked.
        mask_idx = torch.zeros_like(speech_label, dtype=torch.bool)
      ## add weights here
      loss = (
          masked_ce(speech_logits, speech_label.masked_fill(~mask_idx, -100))
          + masked_ce(noise_logits, noise_label.masked_fill(~mask_idx, -100))
          + masked_ce(speech_logits, speech_label.masked_fill(mask_idx, -100))
          + masked_ce(noise_logits, noise_label.masked_fill(mask_idx, -100))
      )
      loss.backward()
      optimizer.step()
      tbar_train.set_postfix(loss=loss.item())
  return hubert_pt



hubert_pt = train()



 13%|█▎        | 3848/28539 [09:13<58:30,  7.03it/s, loss=24.3]

In [None]:
# Inspect the masked encoder on ONE training batch. Looping over the whole loader ran
# the encoder over every utterance with autograd on, and kept only the last output.
with torch.no_grad():
  batch = next(iter(tbar_train))
  encoded = hubert_masked(batch.input_values.to(device), output_hidden_states = True)
mask = encoded["mask_time_index"]
{
    "last_hidden_state": tuple(encoded["last_hidden_state"].shape),
    "num_hidden_states": len(encoded["hidden_states"]),
    "masked_frames": None if mask is None else int(mask.sum()),
}

{'last_hidden_state': tensor([[[-0.2863, -0.3067,  0.2715,  ..., -0.2437,  0.1441, -0.1282],
         [-0.3438, -0.2656,  0.3441,  ..., -0.2969,  0.0926, -0.1174],
         [-0.2804, -0.2965,  0.3916,  ..., -0.0788,  0.0787, -0.1361],
         ...,
         [ 0.1121, -0.2006,  0.2506,  ...,  0.1507,  0.3141, -1.8060],
         [ 0.0206,  0.3762,  0.2665,  ...,  0.1525, -0.0075, -0.5493],
         [-0.0351,  0.0312,  0.2324,  ..., -0.0284, -0.3100, -1.0807]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), 'hidden_states': (tensor([[[-0.7155, -0.3656,  0.0486,  ...,  0.1528, -0.3625,  0.0204],
         [-0.5203, -0.2759, -0.2309,  ...,  0.2330, -0.1025,  0.1139],
         [-0.4519, -0.2913,  0.0141,  ..., -0.1149, -0.0696,  0.0000],
         ...,
         [ 0.1207,  0.3528, -0.0000,  ...,  0.1944, -0.2496, -0.3202],
         [ 0.0140,  0.4282,  0.0873,  ..., -0.0322, -0.0218,  0.2393],
         [ 0.1833,  0.4072,  0.1289,  ...,  0.2877, -0.0694,  0.3452]]],
       device='

# Scratch

In [None]:
# Scratch: re-create the test-clean dataset handle.
librispeech_test = torchaudio.datasets.LIBRISPEECH(root="./data", url="test-clean", download=True)

100%|██████████| 331M/331M [00:20<00:00, 17.0MB/s]


In [None]:
y, sample_rate = librosa.load("/content/MS-SNSD/noise_train/AirConditioner_1.wav", sr=None)

In [None]:
# Pass the clip's sampling rate explicitly. Without it the processor cannot check that the
# audio matches the rate it expects, and it prints a warning instead.
processor(y, sampling_rate=sample_rate, return_tensors="pt").input_values.shape

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 335031])

In [None]:
# Mix the noise clip into the first test-clean utterance. The stored waveform has a
# leading axis, so take row 0: len() of the 2-D tensor is 1 and would keep only a
# single noise sample. Cut both to the shorter length so their shapes match.
speech = librispeech_test[0][0][0].numpy()
mix_len = min(len(speech), len(y))
Audio(data = y[:mix_len] + speech[:mix_len], rate = sample_rate)

In [None]:
from IPython.display import Audio, display
y, sample_rate = librosa.load("/content/MS-SNSD/noise_train/Babble_3.wav", sr=None)
scaling_factor = 2
# Preview the first test-clean utterance mixed with the babble clip amplified by `scaling_factor`.
clean_wave = librispeech_test[0][0][0]
Audio(data = scaling_factor * y[:len(clean_wave)] + clean_wave.numpy(), rate = sample_rate)

In [None]:
librispeech_test[0][0][0, :]  # row 0 of the first test-clean waveform

tensor([0.0003, 0.0003, 0.0004,  ..., 0.0021, 0.0021, 0.0016])

In [None]:
scaling_factor * y[:librispeech_test[0][0].shape[-1]]  # scaled noise cut to the utterance length

array([ 0.        ,  0.        ,  0.        , ..., -0.1027832 ,
       -0.09197998, -0.07806396], dtype=float32)



| eval split | quantization 1 | quantization 1_2 | quantization 2 | baseline |
| -------- | -------- | ------- | -------- | ------- |
| test_clean | | | | |
| test_clean + noise | | | | |
| test_other | | | | |

For test_clean, train ASR on the train set **without** adding noise

For test_clean + noise, train ASR on the train set **with** adding noise

For test_other, train ASR on the "dev-other" split **without** adding noise (download it with code similar to that used to load the other splits).

You will need to modify the ASR dataset structure accordingly.

Please run the ASR experiment on the baseline model **first** to ensure that the ASR code works correctly.

Note that you do not need to train on the whole training set; about 10 hours of data may be enough (see the HuBERT paper for details).

