# 1. Preliminaries

In [None]:
%%capture
!pip install datasets
!pip install evaluate
!pip install jiwer
!pip install wandb
!git clone https://github.com/microsoft/MS-SNSD.git

In [None]:
import torch
import torchaudio
from transformers import AutoProcessor, HubertModel, Wav2Vec2FeatureExtractor, HubertForCTC
import os
import numpy as np
import librosa
from tqdm import tqdm
from IPython.display import Audio, display
import pickle
from typing import Optional, Tuple, Union
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from jiwer import wer
import wandb
from transformers.modeling_outputs import BaseModelOutput
import torch.nn.functional as F

data_dir = './data'
os.makedirs(data_dir, exist_ok=True)

In [None]:
class Config:
    """Hyper-parameters shared by the pre-training and ASR cells.

    Attributes set by the constructor:
        noise_scale: multiplier applied to the noise waveform before it is added to speech.
        masked_speech_loss_ratio / unmasked_speech_loss_ratio: weights of the speech loss on
            masked and unmasked frames; the unmasked weight is ``1 - masked``.
        masked_noise_loss_ratio / unmasked_noise_loss_ratio: same, for the noise loss.
        quantization_method: one of ``"1"``, ``"1_2"`` or ``"2"``; selects the class counts below.
        name: Hugging Face checkpoint identifier.
        use_contrastive_head: if True the heads project to ``emb_size_code`` instead of class logits.
        emb_size_code: output size of the projection heads in contrastive mode.
        speech_nclass / noise_nclass: number of speech / noise classes for the chosen method.

    Raises:
        ValueError: if ``quantization_method`` is not a supported value. (Previously the class
            counts were silently left undefined and the failure surfaced much later.)
    """

    # quantization_method -> (speech_nclass, noise_nclass)
    # NOTE(review): method "2" declares 600 classes while the loading cell reads
    # kmeans_all_500.pkl for it -- confirm the intended class count.
    _NCLASS_BY_METHOD = {
        "1": (500, 500),
        "1_2": (500, 100),
        "2": (600, 600),
    }

    def __init__(self,
                noise_scale = 1,
                masked_speech_loss_ratio = 0.5,
                masked_noise_loss_ratio = 0.5,
                quantization_method = "1",
                name = "facebook/hubert-large-ls960-ft",
                use_contrastive_head = False,
                emb_size_code = 768):
        if quantization_method not in self._NCLASS_BY_METHOD:
            raise ValueError(
                f"Unknown quantization_method {quantization_method!r}; "
                f"expected one of {sorted(self._NCLASS_BY_METHOD)}"
            )
        self.noise_scale = noise_scale
        self.masked_speech_loss_ratio = masked_speech_loss_ratio
        self.unmasked_speech_loss_ratio = 1 - self.masked_speech_loss_ratio
        self.masked_noise_loss_ratio = masked_noise_loss_ratio
        self.unmasked_noise_loss_ratio = 1 - self.masked_noise_loss_ratio
        self.quantization_method = quantization_method
        self.name = name
        self.emb_size_code = emb_size_code
        self.use_contrastive_head = use_contrastive_head
        self.speech_nclass, self.noise_nclass = self._NCLASS_BY_METHOD[quantization_method]

args = Config(use_contrastive_head = True)

In [None]:
# Open the three LibriSpeech splits used in this notebook via torchaudio.
# `download=True` asks torchaudio to fetch the split into "./data" (the same folder as
# `data_dir` created above); indexing a split returns a tuple whose element 0 is the
# waveform and element 2 is the transcript, as used by the dataset classes below.
librispeech_dev = torchaudio.datasets.LIBRISPEECH("./data", url="dev-clean", download=True)
librispeech_train = torchaudio.datasets.LIBRISPEECH("./data", url="train-clean-100", download=True)
librispeech_test = torchaudio.datasets.LIBRISPEECH("./data", url="test-clean", download=True)

In [None]:
# Fetch the pre-computed k-means artefacts from Google Drive into the working directory.
# According to the recorded output, the first six IDs resolve (in order) to:
#   kmeans_labels.pkl, kmeans_noise_500.pkl, kmeans_signal_500.pkl,
#   kmeans_labels_1.pkl, kmeans_all_500.pkl, kmeans_labels_2.pkl
# NOTE(review): the recorded output is truncated, so the targets of the last two IDs
# are not visible here -- confirm they provide the files needed for method "1_2".
!gdown 1wVq5bdLcEjR-qKm2UOl1jCwD7hEAQsDj
!gdown 1QHsrC52r8lhsHBPIy9Din0ir6CWM9nAf
!gdown 1FzwfPHgvZoWSR4iwHdUX_SsuujCSib8o
!gdown 160POl10uVDKhFfc_rdIlMZp-ylbkLLKm
!gdown 13hUU13G5n-eCpDvhKuLCVK5x74c8ULvx
!gdown 1FV4HYI0mVAgCfp5xgdLbhM6xbWd0xBY6
!gdown 1ylPd1iH2HLUoxj_0mfjUSQ5kN1uD3P0D
!gdown 1XdNYOMe62f5UG5_rVh6wy36RKjc-OSIe

Downloading...
From: https://drive.google.com/uc?id=1wVq5bdLcEjR-qKm2UOl1jCwD7hEAQsDj
To: /content/kmeans_labels.pkl
100% 68.7M/68.7M [00:01<00:00, 60.6MB/s]
Downloading...
From: https://drive.google.com/uc?id=1QHsrC52r8lhsHBPIy9Din0ir6CWM9nAf
To: /content/kmeans_noise_500.pkl
100% 3.39M/3.39M [00:00<00:00, 211MB/s]
Downloading...
From: https://drive.google.com/uc?id=1FzwfPHgvZoWSR4iwHdUX_SsuujCSib8o
To: /content/kmeans_signal_500.pkl
100% 3.73M/3.73M [00:00<00:00, 200MB/s]
Downloading...
From: https://drive.google.com/uc?id=160POl10uVDKhFfc_rdIlMZp-ylbkLLKm
To: /content/kmeans_labels_1.pkl
100% 68.6M/68.6M [00:00<00:00, 201MB/s]
Downloading...
From: https://drive.google.com/uc?id=13hUU13G5n-eCpDvhKuLCVK5x74c8ULvx
To: /content/kmeans_all_500.pkl
100% 4.67M/4.67M [00:00<00:00, 148MB/s]
Downloading...
From: https://drive.google.com/uc?id=1FV4HYI0mVAgCfp5xgdLbhM6xbWd0xBY6
To: /content/kmeans_labels_2.pkl
100% 68.6M/68.6M [00:01<00:00, 43.8MB/s]
Downloading...
From: https://drive.google.co

In [None]:
# Load the input processor for the checkpoint. `AutoProcessor` is tried first; if it cannot
# be built for this checkpoint we fall back to the plain feature extractor.
# Only `Exception` is caught (not a bare `except:`), so Ctrl-C / kernel interrupts still
# propagate, and the original failure is printed instead of being silently discarded.
try:
  processor = AutoProcessor.from_pretrained(args.name)
except Exception as processor_error:
  print(f"AutoProcessor failed for {args.name!r} ({processor_error!r}); "
        "falling back to Wav2Vec2FeatureExtractor.")
  processor = Wav2Vec2FeatureExtractor.from_pretrained(args.name)

# The encoder is loaded once, independently of which processor was obtained.
model = HubertModel.from_pretrained(args.name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of HubertModel were not initialized from the model checkpoint at facebook/hubert-large-ls960-ft and are newly initialized: ['hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# 2. Quantization

## 2.1 Preliminary functions

In [None]:
class ds_noise_mssnsd(torch.utils.data.Dataset):
  """Dataset over every ``.wav`` file found (recursively) under ``noise_dir``.

  Each item is the processor output for one noise recording, loaded at 16 kHz,
  multiplied by ``factor`` and cut to at most 200000 samples.
  """

  def __init__(self, noise_dir, processor, factor = 1):
    self.processor = processor
    self.noise_dir = noise_dir
    self.factor = factor
    # Keep os.walk's own traversal order: indices into `file_paths` are positional.
    candidates = (
        os.path.join(folder, fname)
        for folder, _subdirs, fnames in os.walk(noise_dir)
        for fname in fnames
    )
    self.file_paths = [path for path in candidates if path.endswith(".wav")]

  def __len__(self):
    return len(self.file_paths)

  def __getitem__(self, idx):
    waveform, sample_rate = librosa.load(self.file_paths[idx], sr=16000)
    scaled = waveform * self.factor
    encoded = self.processor(scaled[:200000], return_tensors="pt", sampling_rate = sample_rate)
    # Drop the singleton batch axis the processor adds.
    for key in encoded:
      encoded[key] = encoded[key].squeeze()
    return encoded

class ds_signal_mssnsd(torch.utils.data.Dataset):
  """Dataset that runs clean speech through the processor.

  Args:
    data: indexable collection whose items carry the waveform at position 0
      (``data[idx][0]``), e.g. a torchaudio LibriSpeech split.
    processor: callable applied to the waveform; it is called with
      ``return_tensors="pt"`` and ``sampling_rate=16000``.
    leng: optional override for the reported dataset length.
  """

  def __init__(self, data, processor, leng = None):
    self.data = data
    self.processor = processor
    self.leng = leng

  def __len__(self):
    if self.leng is not None:
      return self.leng
    else:
      return len(self.data)

  def __getitem__(self, idx):
    clean_speech = self.data[idx][0].squeeze()
    # Use the processor given to the constructor; the previous code silently relied on a
    # notebook-level global named `processor` and ignored `self.processor`.
    encoded = self.processor(clean_speech[:200000], return_tensors="pt", sampling_rate = 16000)
    # Drop the singleton batch axis the processor adds.
    for k in encoded: encoded[k] = encoded[k].squeeze()
    return encoded

## 2.2 Loading existing quantization

In [None]:
# (frame-label file, noise k-means file, speech k-means file) for each supported method.
quantization_files = {
  "1": ("kmeans_labels_1.pkl", "kmeans_noise_500.pkl", "kmeans_signal_500.pkl"),
  "1_2": ("kmeans_labels_1_2.pkl", "kmeans_noise_100.pkl", "kmeans_signal_500.pkl"),
  "2": ("kmeans_labels_2.pkl", "kmeans_all_500.pkl", "kmeans_all_500.pkl"),
}

# Fail with a clear message instead of a NameError on `file_name` further down.
if args.quantization_method not in quantization_files:
  raise ValueError(
      f"Unknown quantization_method {args.quantization_method!r}; "
      f"expected one of {sorted(quantization_files)}"
  )
file_name, noise_file, signal_file = quantization_files[args.quantization_method]

# SECURITY NOTE: pickle.load executes arbitrary code embedded in the file. These files are
# downloaded from Google Drive above -- only run this cell if that source is trusted.
with open(file_name, 'rb') as file:
  kmeans_labels = pickle.load(file)
with open(noise_file, 'rb') as file:
  kmeans_noise = pickle.load(file)
with open(signal_file, 'rb') as file:
  kmeans_speech = pickle.load(file)

# Pre-computed label sequences, keyed by data source.
tag_train = kmeans_labels["train"]
tag_test = kmeans_labels["test"]
tag_noise = kmeans_labels["noise"]

# 3. Training

## 3.1 Build pre-training model

### 3.1.1 Modify huggingface function

In [None]:
def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    The result is random: it draws from the global NumPy RNG (`np.random.rand` once, then
    `np.random.choice` once per batch row), so seed `np.random` for reproducible masks.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.

    Returns:
        A boolean `np.ndarray` of shape `(batch_size, sequence_length)` that is True at masked positions.
        An all-False array is returned when the number of spans computed for the full length is zero.

    Raises:
        ValueError: if `mask_length` is smaller than 1 or larger than `sequence_length`.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding
    # (drawn once per call, so every row of the batch shares the same rounding offset)
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    # (per-row valid lengths come from the attention mask when one is given)
    input_lengths = (
        attention_mask.sum(-1).detach().tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    # span count for a full-length row; shorter rows are padded up to this count below
    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        # (span start positions, sampled without replacement so that each span fits in the row)
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index that will serve as a dummy index to pad vector
        # to ensure same dimension for all batches due to probabilistic rounding
        # Picking first sample just pads those vectors twice.
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller then
            # `sequence_length` in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask

class masked_hubert(type(model)):
    """Subclass of the loaded encoder's class whose forward pass also reports the time mask.

    Two methods are overridden:
      * `_mask_hidden_states` returns `(hidden_states, mask_time_indices)` instead of only
        the hidden states;
      * `forward` returns a plain dict that additionally carries the mask under the key
        `"mask_time_index"`, so the training loop can split the loss into masked and
        unmasked frames.
    """

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).

        Args:
            hidden_states: tensor of size `(batch_size, sequence_length, hidden_size)`; it is
                modified in place at the masked positions.
            mask_time_indices: optional pre-computed boolean time mask. When omitted, a mask is
                sampled only if `config.mask_time_prob > 0` and the module is in training mode.
            attention_mask: optional mask forwarded to `_compute_mask_indices`.

        Returns:
            A tuple `(hidden_states, mask_time_indices)`. `mask_time_indices` is `None` when no
            time mask was supplied or sampled.
        """

        # `config.apply_spec_augment` can set masking to False.
        # Always return a 2-tuple: `forward` unpacks two values, and returning the bare tensor
        # here made that unpacking fail (or silently split a batch of two along dim 0).
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states, mask_time_indices

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states, mask_time_indices

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) :
        """
        Encode raw audio and report which frames were masked.

        Args:
            input_values: input passed to the convolutional feature extractor.
            attention_mask: optional sample-level mask; it is reduced to frame level before
                being handed to the encoder.
            mask_time_indices: optional pre-computed time mask (see `_mask_hidden_states`).
            output_attentions / output_hidden_states / return_dict: fall back to the values in
                `self.config` when left as `None`.

        Returns:
            If `return_dict` resolves to False: the tuple `(last_hidden_state, *encoder_extras)`
            (the mask is NOT included in this form).
            Otherwise a dict with keys `"last_hidden_state"`, `"hidden_states"`, `"attentions"`
            and `"mask_time_index"`. `"mask_time_index"` is `None` whenever no time mask was
            given or sampled (e.g. in eval mode), so callers that index with it must check.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states = self.feature_projection(extract_features)
        # NOTE(review): the attention mask is not forwarded here, so sampled spans may land on
        # padded frames -- kept as in the original; confirm whether that is intended.
        hidden_states, mask_time_index = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return {
            "last_hidden_state" : hidden_states,
            "hidden_states" : encoder_outputs.hidden_states,
            "attentions" : encoder_outputs.attentions,
            "mask_time_index" : mask_time_index,
        }

### 3.1.2 Other functions

In [None]:
class hubert_pretrain(torch.nn.Module):
    """Encoder plus two linear heads (noise and speech) used for pre-training.

    Args:
        model: encoder module. It must expose `config.hidden_size` and, when called with
            `(input_values, attention_mask)`, return a mapping with the keys
            `"last_hidden_state"` and `"mask_time_index"`.
        args: configuration object providing `speech_nclass`, `noise_nclass`,
            `use_contrastive_head` and `emb_size_code`.
        kmeans_noise / kmeans_speech: objects exposing `cluster_centers_`; required only when
            `args.use_contrastive_head` is True.

    In classification mode the heads emit class logits directly. In contrastive mode they
    project to `emb_size_code` and the logits are cosine similarities to the cluster centres,
    divided by the temperature `tau`.
    """

    def __init__(self, model, args, kmeans_noise = None, kmeans_speech = None):
        super(hubert_pretrain, self).__init__()
        self.model = model
        self.emb_size = self.model.config.hidden_size
        self.args = args
        self.kmeans_noise = kmeans_noise
        self.kmeans_speech = kmeans_speech
        speech_out_size = args.speech_nclass if not args.use_contrastive_head else args.emb_size_code
        noise_out_size = args.noise_nclass if not args.use_contrastive_head else args.emb_size_code
        self.linear_noise = torch.nn.Linear(self.emb_size, noise_out_size)
        self.linear_speech = torch.nn.Linear(self.emb_size, speech_out_size)
        # Softmax temperature applied to the cosine similarities.
        self.tau = 0.01

    def cos_sim(self, cluster_center, logits):
        """Cosine similarity between every frame embedding and every cluster centre.

        Args:
            cluster_center: tensor `[num_class, emb_size]`.
            logits: tensor `[bs, seq_len, emb_size]`.

        Returns:
            Tensor `[bs, seq_len, num_class]` with the dtype and device of `logits`.

        The previous implementation broadcast both operands to
        `[bs, seq_len, num_class, emb_size]` before reducing. With the notebook's settings
        (4 x 624 x 500 x 768) that intermediate is about 7.14 GiB in float64, which is
        consistent with the CUDA out-of-memory error recorded for the training cell.
        Normalising each operand and using one matrix product gives the same values without
        ever creating the 4-D tensor.
        """
        # Match dtype/device of the activations: centres built from NumPy arrays may be
        # float64, which would otherwise promote the whole computation.
        centers = cluster_center.to(device=logits.device, dtype=logits.dtype)

        # eps matches the default of F.cosine_similarity.
        eps = 1e-8
        logits_unit = F.normalize(logits, p=2, dim=-1, eps=eps)
        centers_unit = F.normalize(centers, p=2, dim=-1, eps=eps)

        # [bs, seq_len, emb] @ [emb, num_class] -> [bs, seq_len, num_class]
        return torch.matmul(logits_unit, centers_unit.transpose(0, 1))

    def forward(self, data):
        """Run the encoder and both heads.

        Args:
            data: batch object exposing `input_values` and `attention_mask` tensors.

        Returns:
            `(noise_logits, speech_logits, mask_idx)` where `mask_idx` is whatever the encoder
            returned under `"mask_time_index"`.

        Raises:
            ValueError: in contrastive mode when a k-means model is missing.
        """
        # Follow the module's own placement rather than a notebook-level `device` global.
        module_device = self.linear_noise.weight.device
        encoded = self.model(data.input_values.to(module_device), data.attention_mask.to(module_device))
        noise_logits = self.linear_noise(encoded["last_hidden_state"])
        speech_logits = self.linear_speech(encoded["last_hidden_state"])
        mask_idx = encoded["mask_time_index"]
        if self.args.use_contrastive_head:
            if self.kmeans_noise is None or self.kmeans_speech is None:
                raise ValueError("use_contrastive_head=True requires both kmeans_noise and kmeans_speech")
            noise_center = torch.as_tensor(
                self.kmeans_noise.cluster_centers_, dtype=noise_logits.dtype, device=noise_logits.device
            )
            speech_center = torch.as_tensor(
                self.kmeans_speech.cluster_centers_, dtype=speech_logits.dtype, device=speech_logits.device
            )
            noise_logits = self.cos_sim(noise_center,  noise_logits) /  self.tau
            speech_logits = self.cos_sim(speech_center, speech_logits) / self.tau
        return noise_logits, speech_logits, mask_idx

In [None]:
class pre_train_dataset(torch.utils.data.Dataset):
    """Noisy-speech dataset for pre-training.

    Each item mixes one clean utterance with randomly chosen noise recordings and returns the
    processor output together with two label sequences (`speech_label`, `noise_label`) padded
    with -100 to `MAX_FRAMES`.

    Args:
        data: indexable collection whose items carry the waveform at position 0.
        processor: callable used to turn the noisy waveform into model inputs.
        noise_dir: directory searched recursively for ``.wav`` noise files.
        data_label: per-utterance label sequences, indexed like `data`.
        noise_label: per-noise-file label sequences, indexed like `self.file_paths`.
        args: configuration object providing `noise_scale`.
    """

    # Waveforms are cut / padded to this many samples.
    MAX_SAMPLES = 200000
    # Label sequences are padded to this length -- presumably the number of encoder frames
    # produced for MAX_SAMPLES samples; TODO confirm against the model's frame rate.
    MAX_FRAMES = 624

    def __init__(self, data, processor, noise_dir, data_label, noise_label, args):
        self.data = data
        self.processor = processor
        self.noise_dir = noise_dir
        self.data_label = data_label
        self.noise_label = noise_label
        self.args = args
        # Keep os.walk's traversal order: `noise_label` is indexed by position in this list.
        file_paths = []
        for root, dirs, files in os.walk(noise_dir):
            for file in files:
                file_path = os.path.join(root, file)
                if file_path.endswith(".wav"):
                    file_paths.append(file_path)

        self.file_paths = file_paths

    def __len__(self):
        return len(self.data)

    def pad_label(self, label, leng):
        """Return an int64 tensor of length `leng`: `label` truncated, then padded with -100."""
        result = torch.full(size = (leng, ), fill_value = -100)
        n_copy = min(leng, len(label))
        if n_copy > 0:
            # One vectorised copy instead of a Python loop over every frame.
            result[:n_copy] = torch.as_tensor(label[:n_copy], dtype=result.dtype)
        return result

    def _sample_noise(self, min_length):
        """Concatenate random noise files (and their labels) until `min_length` samples exist."""
        noise_chunks = []
        label_chunks = []
        total = 0
        # At least one file is always drawn, even for an empty utterance.
        while total < min_length or not noise_chunks:
            noise_idx = np.random.choice(len(self.file_paths))
            chunk, _sample_rate = librosa.load(self.file_paths[noise_idx], sr=16000)
            noise_chunks.append(chunk)
            label_chunks.append(self.noise_label[noise_idx])
            total += len(chunk)
        return np.concatenate(noise_chunks), np.concatenate(label_chunks)

    def __getitem__(self, idx):
        clean_speech = self.data[idx][0].squeeze()
        n_samples = len(clean_speech)
        noise, noise_label = self._sample_noise(n_samples)

        # Convert explicitly so the mix is a plain tensor + tensor operation.
        scaled_noise = torch.as_tensor(noise[:n_samples]) * self.args.noise_scale
        noisy_speech = clean_speech + scaled_noise

        # Use the processor given to the constructor; the previous code silently relied on a
        # notebook-level global named `processor` and ignored `self.processor`.
        encoded = self.processor(
            noisy_speech[:self.MAX_SAMPLES], return_tensors="pt",
            sampling_rate = 16000, padding = "max_length",
            max_length = self.MAX_SAMPLES, truncation=True,
            return_attention_mask = True
        )
        encoded["speech_label"] = self.pad_label(self.data_label[idx], self.MAX_FRAMES)
        encoded["noise_label"] = self.pad_label(noise_label, self.MAX_FRAMES)
        # Drop the singleton batch axis the processor adds.
        for k in encoded: encoded[k] = encoded[k].squeeze()
        return encoded

## 3.2 Train function

In [None]:
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
def train(num_epochs = 10, batch_size = 4, lr = 1e-5):
    """Pre-train the masked HuBERT encoder together with its noise and speech heads.

    A `masked_hubert` copy of the notebook-level `model` is wrapped in `hubert_pretrain`
    and optimised with four cross-entropy terms (speech/noise x masked/unmasked frames),
    weighted by the ratios stored in `args`. Metrics are logged to Weights & Biases.

    Args:
        num_epochs: number of passes over the training set (default 10, as before).
        batch_size: training batch size (default 4, as before).
        lr: Adam learning rate (default 1e-5, as before).

    Returns:
        The trained `hubert_pretrain` module.

    Raises:
        RuntimeError: if the encoder did not produce a time mask (the masked/unmasked split
            needs one).
    """
    wandb.init(project="hubert_pretrain_1", entity="yilin_wang")
    try:
        hubert_masked = masked_hubert(model.config)
        hubert_masked.load_state_dict(model.state_dict())
        hubert_pt = hubert_pretrain(hubert_masked, args, kmeans_noise, kmeans_speech)
        hubert_pt = hubert_pt.to(device)
        train_dataset = pre_train_dataset(librispeech_train, processor,
                                          noise_dir = "/content/MS-SNSD/noise_train/",
                                          data_label = tag_train,
                                          noise_label = tag_noise,
                                          args = args)
        train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
        # Optimise the whole wrapper. Passing only the encoder's parameters left the two
        # linear heads frozen at their random initialisation.
        optimizer = torch.optim.Adam(hubert_pt.parameters(), lr = lr)
        criterion = torch.nn.CrossEntropyLoss()
        for epoch in range(num_epochs):
            hubert_pt.train()
            tbar_train = tqdm(train_dataloader, position=0, leave=True)
            for batch in tbar_train:
                optimizer.zero_grad()
                noise_logits, speech_logits, mask_idx = hubert_pt(batch)
                if mask_idx is None:
                    raise RuntimeError(
                        "The encoder returned no time mask; check config.mask_time_prob "
                        "and that the model is in training mode."
                    )
                speech_label = batch["speech_label"].to(device)
                noise_label = batch["noise_label"].to(device)

                # -100 is ignored by CrossEntropyLoss, so each variant keeps only one kind
                # of frame: masked frames or unmasked frames.
                speech_label_masked = speech_label.masked_fill(~mask_idx, -100)
                noise_label_masked = noise_label.masked_fill(~mask_idx, -100)
                speech_label_unmasked = speech_label.masked_fill(mask_idx, -100)
                noise_label_unmasked = noise_label.masked_fill(mask_idx, -100)

                # Take the class count from the logits instead of hard-coding 500, which
                # was wrong for quantization methods with a different number of classes.
                speech_flat = speech_logits.reshape(-1, speech_logits.size(-1))
                noise_flat = noise_logits.reshape(-1, noise_logits.size(-1))

                loss_speech_masked = criterion(speech_flat, speech_label_masked.view(-1))
                loss_noise_masked = criterion(noise_flat, noise_label_masked.view(-1))
                loss_speech_unmasked = criterion(speech_flat, speech_label_unmasked.view(-1))
                loss_noise_unmasked = criterion(noise_flat, noise_label_unmasked.view(-1))

                loss = args.masked_speech_loss_ratio * loss_speech_masked + \
                      args.unmasked_speech_loss_ratio * loss_speech_unmasked + \
                      args.masked_noise_loss_ratio * loss_noise_masked + \
                      args.unmasked_noise_loss_ratio * loss_noise_unmasked
                loss.backward()
                optimizer.step()
                tbar_train.set_postfix(loss=loss.item())
                wandb.log({"loss": loss.item(),
                          "loss_speech_masked": loss_speech_masked.item(),
                          "loss_noise_masked": loss_noise_masked.item(),
                          "loss_speech_unmasked": loss_speech_unmasked.item(),
                          "loss_noise_unmasked": loss_noise_unmasked.item()})
    finally:
        # Close the run even when training aborts (e.g. on a CUDA out-of-memory error).
        wandb.finish()
    return hubert_pt



# Run pre-training; the returned wrapper's `.model` weights are reused by the ASR cell below.
hubert_trained = train()



[34m[1mwandb[0m: Currently logged in as: [33myilin_wang[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/7135 [00:02<?, ?it/s]

********************
------------------------------------------------------------





OutOfMemoryError: CUDA out of memory. Tried to allocate 7.14 GiB. GPU 0 has a total capacity of 39.56 GiB of which 5.27 GiB is free. Process 82376 has 34.28 GiB memory in use. Of the allocated memory 30.16 GiB is allocated by PyTorch, and 3.63 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

# 4. Evaluation on ASR

In [None]:
class asr_dataset(torch.utils.data.Dataset):
    """Noisy-speech dataset for CTC fine-tuning and evaluation.

    Each item mixes one clean utterance with randomly chosen noise recordings and returns the
    processor output plus the tokenised transcript under `"labels"`.

    Args:
        data: indexable collection whose items carry the waveform at position 0 and the
            transcript at position 2.
        processor: callable used for the audio (positional call) and for the transcript
            (`text=` keyword call, whose result must expose `.input_ids`).
        noise_dir: directory searched recursively for ``.wav`` noise files.
        data_label / noise_label: stored on the instance; not used when building items.
        leng: optional override for the reported dataset length.
    """

    # Waveforms are cut to this many samples before being processed.
    MAX_SAMPLES = 200000

    def __init__(self, data, processor, noise_dir, data_label, noise_label, leng = None):
        self.data = data
        self.processor = processor
        self.noise_dir = noise_dir
        self.data_label = data_label
        self.noise_label = noise_label
        self.leng = leng
        file_paths = []
        for root, dirs, files in os.walk(noise_dir):
            for file in files:
                file_path = os.path.join(root, file)
                if file_path.endswith(".wav"):
                  file_paths.append(file_path)

        self.file_paths = file_paths

    def __len__(self):
        if self.leng is not None:
            return self.leng
        return len(self.data)

    def pad_label(self, label, leng):
        """Return an int64 tensor of length `leng`: `label` truncated, then padded with -100."""
        result = torch.full(size = (leng, ), fill_value = -100)
        n_copy = min(leng, len(label))
        if n_copy > 0:
            # One vectorised copy instead of a Python loop over every frame.
            result[:n_copy] = torch.as_tensor(label[:n_copy], dtype=result.dtype)
        return result

    def _sample_noise(self, min_length):
        """Concatenate randomly chosen noise files until at least `min_length` samples exist."""
        noise_chunks = []
        total = 0
        # At least one file is always drawn, even for an empty utterance.
        while total < min_length or not noise_chunks:
            noise_idx = np.random.choice(len(self.file_paths))
            chunk, _sample_rate = librosa.load(self.file_paths[noise_idx], sr=16000)
            noise_chunks.append(chunk)
            total += len(chunk)
        return np.concatenate(noise_chunks)

    def __getitem__(self, idx):
        # Fetch the sample once; it provides both the waveform and the transcript.
        sample = self.data[idx]
        clean_speech = sample[0].squeeze()
        n_samples = len(clean_speech)
        noise = self._sample_noise(n_samples)

        # Convert explicitly so the mix is a plain tensor + tensor operation.
        noisy_speech = clean_speech + torch.as_tensor(noise[:n_samples])
        encoded = self.processor(
            noisy_speech[:self.MAX_SAMPLES], return_tensors="pt",
            sampling_rate = 16000,
        )
        # The transcript must come from the wrapped dataset. The previous code always read
        # the global `librispeech_test`, so a training-split dataset was paired with
        # unrelated test-split transcripts.
        encoded["labels"] = self.processor(text=sample[2], return_tensors="pt").input_ids
        # Drop the singleton batch axis the processor adds.
        for k in encoded: encoded[k] = encoded[k].squeeze()
        return encoded

In [None]:
def train_asr(hubert_base, num_epochs = 10, lr = 1e-5):
    """Fine-tune a CTC model for ASR on noisy LibriSpeech, starting from `hubert_base`.

    Args:
        hubert_base: encoder module that replaces the `hubert` attribute of the freshly
            loaded `HubertForCTC` checkpoint.
        num_epochs: number of passes over the training set (default 10, as before).
        lr: Adam learning rate (default 1e-5, as before).

    Returns:
        The fine-tuned `HubertForCTC` model.
    """
    hubert_ctc = HubertForCTC.from_pretrained(args.name)
    hubert_ctc.hubert = hubert_base
    hubert_ctc.to(device)
    train_dataset = asr_dataset(librispeech_train, processor,
                                      noise_dir = "/content/MS-SNSD/noise_train/",
                                      data_label = tag_train,
                                      noise_label = tag_noise)
    # NOTE(review): batches are drawn in dataset order (shuffle=False), as in the original;
    # confirm that unshuffled training is intended.
    train_dataloader = DataLoader(train_dataset, batch_size = 1, shuffle = False)
    optimizer = torch.optim.Adam(hubert_ctc.parameters(), lr = lr)

    for epoch in range(num_epochs):
        # Put the model being optimised into training mode. The previous code called
        # `model.train()` on the unrelated notebook-level encoder instead.
        hubert_ctc.train()
        tbar_train = tqdm(train_dataloader, position=0, leave=True)
        for batch in tbar_train:
            optimizer.zero_grad()
            encoded = hubert_ctc(batch.input_values.to(device),
                                labels = batch.labels.to(device)
                                )
            loss = encoded.loss
            loss.backward()
            optimizer.step()
            tbar_train.set_postfix(loss=loss.item())
    return hubert_ctc

# Copy the pre-trained encoder weights into a stock HubertModel so that it can be plugged
# into HubertForCTC by `train_asr`.
base_model = HubertModel.from_pretrained(args.name)
base_model.load_state_dict(hubert_trained.model.state_dict())

hubert_ctc = train_asr(base_model)

Some weights of HubertModel were not initialized from the model checkpoint at facebook/hubert-large-ls960-ft and are newly initialized: ['hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of the model checkpoint at facebook/hubert-large-ls960-ft were not used when initializing HubertForCTC: ['hubert.encoder.pos_conv_embed.conv.weight_g', 'hubert.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing HubertForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing HubertForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceCla

In [None]:
def eval_asr(model, loader):
    """Compute the word error rate of `model` over every batch in `loader`.

    The model is switched to eval mode, greedy (argmax) predictions and the reference label
    ids are decoded with the notebook-level `processor`, and `wer` is computed once over all
    decoded references and predictions.
    """
    model.eval()
    hypotheses, targets = [], []
    with torch.no_grad():
        for batch in tqdm(loader, position=0, leave=True):
            audio = batch['input_values'].to(device)
            label_ids = batch['labels']

            # Greedy decoding: most likely token id per frame.
            token_ids = model(audio).logits.argmax(dim=-1)

            hypotheses += processor.batch_decode(token_ids)
            targets += processor.batch_decode(label_ids)

    return wer(targets, hypotheses)


In [None]:
# Baseline: the published CTC checkpoint without any additional training, evaluated on
# the first 500 noisy test utterances.
# `args.name` is the checkpoint identifier; the bare `name` used before is not defined
# anywhere in this notebook and raises NameError on a fresh kernel.
hubert_ctc = HubertForCTC.from_pretrained(args.name)
hubert_ctc.to(device)
train_dataset = asr_dataset(librispeech_train, processor,
                                  noise_dir = "/content/MS-SNSD/noise_train/",
                                  data_label = tag_train,
                                  noise_label = tag_noise)
test_dataset = asr_dataset(librispeech_test, processor,
                                  noise_dir = "/content/MS-SNSD/noise_test/",
                                  data_label = tag_test,
                                  noise_label = tag_noise,
                                  leng = 500)
train_dataloader = DataLoader(train_dataset, batch_size = 1, shuffle = False)
test_dataloader = DataLoader(test_dataset, batch_size = 1, shuffle = False)

Some weights of the model checkpoint at facebook/hubert-large-ls960-ft were not used when initializing HubertForCTC: ['hubert.encoder.pos_conv_embed.conv.weight_g', 'hubert.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing HubertForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing HubertForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of HubertForCTC were not initialized from the model checkpoint at facebook/hubert-large-ls960-ft and are newly initialized: ['hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-

In [None]:
eval_asr(hubert_ctc, test_dataloader)

100%|██████████| 500/500 [00:31<00:00, 16.03it/s]


0.6263713729108992