In [1]:
from transformers import (
    AdamW,
    Wav2Vec2Config,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForPreTraining,
    get_scheduler
)
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    _compute_mask_indices, _sample_negative_indices
)
import torch
from torch.utils.data.dataloader import DataLoader, Dataset
import lightning as L
from dataclasses import dataclass
import pandas as pd

In [2]:
from typing import Union, List, Optional, Dict

@dataclass
class DataCollatorForWav2Vec2Pretraining:
    model: Wav2Vec2ForPreTraining
    feature_extractor: Wav2Vec2FeatureExtractor
    padding: Union[bool, str] = "longest"
    pad_to_multiple_of: Optional[int] = None
    mask_time_prob: Optional[float] = 0.65
    mask_time_length: Optional[int] = 10

    def __call__(
            self,
            features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        batch = self.feature_extractor.pad(
            features,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors='pt'
        )
        device = batch['input_values'].device
        batch_size = batch['input_values'].shape[0]

        mask_indices_seq_length = self.model._get_feat_extract_output_lengths(
            batch["input_values"].shape[-1]
        )
        mask_indices_seq_length = int(mask_indices_seq_length)

        if batch.get("attention_mask") is not None:
            batch['sub_attention_mask'] = self.model._get_feature_vector_attention_mask(
                mask_indices_seq_length, batch["attention_mask"]
            )
        
        features_shape = (batch_size, mask_indices_seq_length)

        # Sample Randomly Masked Indices
        mask_time_indices = _compute_mask_indices(
            features_shape,
            self.mask_time_prob,
            self.mask_time_length,
            attention_mask=batch.get("sub_attention_mask")
        )

        # Sample Negative indices
        sampled_negative_indices = _sample_negative_indices(
            features_shape,
            self.model.config.num_negatives,
            mask_time_indices=mask_time_indices
        )

        batch["mask_time_indices"] = torch.tensor(
            mask_time_indices, dtype=torch.long, device=device
        )
        batch['sampled_negative_indices'] = torch.tensor(
            sampled_negative_indices, dtype=torch.long, device=device
        )
        return batch

In [3]:
def multiply_grads(params, c):
    """Multiplies grads by a constant *c*."""
    for p in params:
        if p.grad is not None:
            if torch.is_tensor(c):
                c = c.to(p.grad.device)
            p.grad.data.mul_(c)

def get_grad_norm(params, scale=1):
    """Compute grad norm given a gradient scale."""
    total_norm = 0.0
    for p in params:
        if p.grad is not None:
            param_norm = (p.grad.detach().data / scale).norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    return total_norm

## Dataset Loading and Preprocessing

In [1]:
from glob import glob
from tqdm.auto import tqdm
from pydub import AudioSegment
audio_paths = []
for audio_path in tqdm(glob("v1/*/v1a/train/*.wav")):
    if audio_path not in audio_paths:
        audio_len = AudioSegment.from_file(audio_path).duration_seconds
    if audio_len > 2 and audio_len < 20:
        audio_paths.append(audio_path)

  0%|          | 0/149826 [00:00<?, ?it/s]

In [2]:
with open("filenames.data.txt", "w") as f:
    f.write("\n".join(audio_paths))

In [4]:
from pydub import AudioSegment
import pandas as pd
audios = pd.DataFrame(open("filenames.data.txt").readlines(), columns=['path'])
audios['path'] = audios['path'].apply(lambda x: x.strip())
from multiprocessing import Pool
from tqdm.auto import tqdm
def get_duration(x):
    return AudioSegment.from_file(x).duration_seconds
with Pool(8) as p:
    audios['duration'] = list(tqdm(p.imap(get_duration, audios['path']), total=len(audios.path)))

  0%|          | 0/67941 [00:00<?, ?it/s]

In [7]:
audios.sort_values('duration', ascending=False, inplace=True)

In [9]:
audios.reset_index(drop=True, inplace=True)

In [10]:
validation_index = int(len(audios) * 0.9)

In [14]:
train, val = audios[:validation_index], audios[validation_index:]

In [15]:
train['path'].to_csv("train.data.txt", header=False, index=False)
val['path'].to_csv("validation.data.txt", header=False, index=False)

## Training Setup

In [4]:
from safetensors.torch import load_file
class AudioDataset(Dataset):
    def __init__(self, audio_paths):
        self.audio_sf_paths = [wav.replace('.wav', '.safetensors') for wav in audio_paths]
        
    def __len__(self):
        return len(self.audio_sf_paths)

    def __getitem__(self, idx):
        batch = load_file(self.audio_sf_paths[idx])
        return batch

In [5]:
import torch
import torch.distributed as dist

def gather_for_metrics(metric_tensor):
    """
    Gathers and sums metrics across all processes in a distributed training environment.
    
    Args:
    metric_tensor (torch.Tensor): A tensor containing the metric to aggregate.

    Returns:
    torch.Tensor: The aggregated metric.
    """
    if not dist.is_initialized():
        raise RuntimeError("Distributed package is not initialized")
    
    # Ensure the tensor is on the correct device
    metric_tensor = metric_tensor.to(dtype=torch.float32)

    # Use all_reduce to sum up all the metrics across all processes
    dist.all_reduce(metric_tensor, op=dist.ReduceOp.SUM)

    return metric_tensor


In [6]:
from typing import Any
from torch.optim.optimizer import Optimizer


class Wav2Vec2PretrainingModule(L.LightningModule):
    def __init__(
            self, 
            model_name: str,
            datasets_path: List[str],
            learning_rate: float = 1e-3, 
            batch_size_per_device: int = 2,
            lr_warmup_steps: int = 10, 
            max_gumbel_temperature: float = 2.0,
            min_gumbel_temperature: float = 0.5,
            gumbel_temperature_decay: float = 0.999995
    ):
        super(Wav2Vec2PretrainingModule, self).__init__()
        self.learning_rate = learning_rate
        self.batch_size = batch_size_per_device
        self.lr_warmup_steps = lr_warmup_steps
        self.model_name = model_name

        config = Wav2Vec2Config.from_pretrained(model_name)
        self.mask_time_prob = config.mask_time_prob if config.mask_time_prob>0 else 0.65
        self.mask_time_length = config.mask_time_length if config.mask_time_length>0 else 10  
        
        self.max_gumbel_temperature = max_gumbel_temperature
        self.min_gumbel_temperature = min_gumbel_temperature
        self.gumbel_temperature_decay = gumbel_temperature_decay
        self.save_hyperparameters()

        self.model = Wav2Vec2ForPreTraining(config)
        self.model.gradient_checkpointing_enable()
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
        self.data_collator = DataCollatorForWav2Vec2Pretraining(
            model=self.model,
            feature_extractor=feature_extractor,
            mask_time_prob=self.mask_time_prob,
            mask_time_length=self.mask_time_length
        )

        train_path, val_path = datasets_path
        self.trainset = AudioDataset(pd.read_csv(train_path, header=None)[0].tolist())
        self.valset = AudioDataset(pd.read_csv(val_path, header=None)[0].tolist())
    
    def train_dataloader(self):
        return DataLoader(
            self.trainset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.data_collator
        )
    
    def val_dataloader(self):
        return DataLoader(
            self.valset,
            batch_size=self.batch_size*2,
            shuffle=False,
            collate_fn=self.data_collator
        )
    
    def training_step(self, batch, batch_idx):
        num_losses = batch['mask_time_indices'].sum()
        sub_attention_mask = batch.pop("sub_attention_mask", None)
        sub_attention_mask = (
            sub_attention_mask if sub_attention_mask is not None else torch.ones_like(batch["masked_time_indices"])
        )

        outputs = self.model(**batch)
    
        if self.trainer.num_nodes>1:
            num_losses = gather_for_metrics(num_losses).sum()
            gradient_multiplier = self.trainer.num_nodes / num_losses
            multiply_grads(self.model.module.parameters(), gradient_multiplier)
        else:
            multiply_grads(self.model.parameters(), 1.0 / num_losses)
        return outputs
    
    def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> None:
        # For logging
        loss_log = outputs['loss'].detach()
        contrastive_loss_log = outputs['contrastive_loss'].detach()
        diversity_loss_log = outputs['diversity_loss'].detach()

        if self.trainer.world_size > 1:
            loss_log = gather_for_metrics(loss_log)
            contrastive_loss_log = gather_for_metrics(contrastive_loss_log)
            diversity_loss_log = gather_for_metrics(diversity_loss_log)
        self.log('loss', loss_log, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('contrastive_loss', contrastive_loss_log, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('diversity_loss', diversity_loss_log, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    
    def validation_step(self, batch, batch_idx):
        batch.pop("sub_attention_mask", None)
        outputs = self.model(**batch)
        loss = outputs.loss
        contrastive_loss = outputs.contrastive_loss
        diversity_loss = outputs.diversity_loss
        if self.trainer.world_size > 1:
            loss = gather_for_metrics(loss)
            contrastive_loss = gather_for_metrics(contrastive_loss)
            diversity_loss = gather_for_metrics(diversity_loss)

        num_losses = batch['mask_time_indices'].sum()

        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_contrastive_loss", contrastive_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_diversity_loss", diversity_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_num_losses", num_losses, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = AdamW(
            list(self.model.parameters()),
            lr = self.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-6,
            weight_decay=0.01
        )
        lr_scheduler = get_scheduler(
            "linear",
            optimizer= optimizer,
            num_training_steps=self.trainer.estimated_stepping_batches,
            num_warmup_steps=self.lr_warmup_steps
        )
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
        gumbel_temperature = max(
                self.max_gumbel_temperature * self.gumbel_temperature_decay**(self.global_step%(self.batch_size*self.trainer.num_nodes)),
                self.min_gumbel_temperature,
            )
        if hasattr(self.model, "module"):
            self.model.module.set_gumbel_temperature(gumbel_temperature)
        else:
            self.model.set_gumbel_temperature(gumbel_temperature)

In [None]:
from argparse import ArgumentParser
import os
def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--train_path", type=str, required=True)
    parser.add_argument("--val_path", type=str, required=True)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--lr_warmup_steps", type=int, default=32000)
    parser.add_argument("--output_dir", type=str, default="wav2vec2-indic-voices")
    parser.add_argument("--accelerator", type=str, default="gpu")
    parser.add_argument("--devices", type=int, default=-1)
    parser.add_argument("--precision", type=Any[str, int], default=16)
    parser.add_argument('--training_steps', type=int, default=200000)
    parser.add_argument("--accumulate_grad_batches", type=int, default=8)
    parser.add_argument("--gradient_clip_val", type=float, default=8)

    parser.add_argument("--max_gumbel_temperature", type=float, default=2.0)
    parser.add_argument("--min_gumbel_temperature", type=float, default=0.5)
    parser.add_argument("--gumbel_temperature_decay", type=float, default=0.999995)
    parser.add_argument("--save_weights_only", action="store_true")
    parser.add_argument("--save_every_n_steps", type=int, default=10000)
    return parser.parse_args()

## Write CLI Command without comments hash to run this script only for those values that are required, make it multi line


In [10]:
model = Wav2Vec2PretrainingModule(
    model_name="facebook/wav2vec2-large-xlsr-53",
    datasets_path=["train.data.txt", "validation.data.txt"],
    learning_rate=1e-3,
    batch_size_per_device=5,
    lr_warmup_steps=32000
)



In [11]:
ckpting = L.pytorch.callbacks.ModelCheckpoint(
    dirpath="wav2vec2-indic-voices",
    filename="{step:06}.ckpt",
    save_weights_only=True,
    every_n_train_steps=10000,
)
tensorboardlogger = L.pytorch.loggers.TensorBoardLogger(
    save_dir="wav2vec2-indic-voices",
    name="wav2vec2-indic-voices-tb",
    version=0
)
trainer = L.Trainer(
    accelerator="gpu",
    devices=-1,
    precision=16,
    max_steps=200000,
    num_sanity_val_steps=1,
    logger=tensorboardlogger,
    callbacks=[ckpting],
    accumulate_grad_batches=8,
    gradient_clip_val=8,
    gradient_clip_algorithm="norm",
)

/home/hike-e2e/.conda3/envs/research-tts/lib/python3.10/site-packages/lightning/fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [12]:
trainer.fit(
    model
)

/home/hike-e2e/.conda3/envs/research-tts/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/hike-e2e/indicVoices/wav2vec2-indic-voices exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | Wav2Vec2ForPreTraining | 317 M  | train
---------------------------------------------------------
317 M     Trainable params
0         Non-trainable params
317 M     Total params
1,269.562 Total estimated model params size (MB)
412       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/hike-e2e/.conda3/envs/research-tts/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/hike-e2e/.conda3/envs/research-tts/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
