In [1]:
import argparse
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Union

import datasets
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import DatasetDict, concatenate_datasets, load_dataset
from huggingface_hub import HfApi
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm

import transformers
from transformers import (
    AdamW,
    SchedulerType,
    get_scheduler,
    is_wandb_available,
    set_seed,
)
from model import (
    Wav2Vec2ForPreTraining,
    Wav2Vec2Config,
    Wav2Vec2FeatureEncoder
)

from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
from transformers.utils import send_example_telemetry

In [2]:
@dataclass
class DataCollatorForWav2Vec2Pretraining:
    """
    Collate a list of dataset examples into one padded pre-training batch.

    For every batch the collator
      1. takes the first element of each example's ``"input_values"`` entry,
      2. pads the batch through ``feature_extractor.pad`` and returns PyTorch tensors,
      3. samples the time steps to mask (``mask_time_indices``), and
      4. samples negative indices for those steps (``sampled_negative_indices``).

    Attributes:
        model: Model whose ``_get_feat_extract_output_lengths`` /
            ``_get_feature_vector_attention_mask`` helpers and
            ``config.num_negatives`` are used to size the masks.
        feature_extractor: Object whose ``pad`` method builds the padded batch.
        padding: Padding strategy forwarded to ``feature_extractor.pad``.
        pad_to_multiple_of: Forwarded unchanged to ``feature_extractor.pad``.
        mask_time_prob: Masking probability forwarded to ``_compute_mask_indices``.
        mask_time_length: Mask span length forwarded to ``_compute_mask_indices``.
    """

    model: Wav2Vec2ForPreTraining
    feature_extractor: Wav2Vec2FeatureEncoder
    padding: Union[bool, str] = "longest"
    pad_to_multiple_of: Optional[int] = None
    mask_time_prob: Optional[float] = 0.65
    mask_time_length: Optional[int] = 10

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``features`` and attach freshly sampled mask / negative indices.

        Returns the padded batch with ``mask_time_indices`` and
        ``sampled_negative_indices`` added, plus ``sub_attention_mask`` when the
        padded batch carries an ``attention_mask``.
        """
        # each example stores its waveform as the first element of "input_values"
        waveforms = {"input_values": [example["input_values"][0] for example in features]}

        batch = self.feature_extractor.pad(
            waveforms,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        padded_inputs = batch["input_values"]
        device = padded_inputs.device
        batch_size = padded_inputs.shape[0]

        # length of the feature sequence for the padded input length, as a plain Python int
        num_frames = int(self.model._get_feat_extract_output_lengths(padded_inputs.shape[-1]))

        # when padding information is available, derive the frame-level mask so that
        # padded frames are excluded from mask sampling
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            batch["sub_attention_mask"] = self.model._get_feature_vector_attention_mask(
                num_frames, attention_mask
            )

        features_shape = (batch_size, num_frames)

        # which frames get masked
        mask_time_indices = _compute_mask_indices(
            features_shape,
            self.mask_time_prob,
            self.mask_time_length,
            attention_mask=batch.get("sub_attention_mask"),
        )

        # negatives for the masked frames
        sampled_negative_indices = _sample_negative_indices(
            features_shape,
            self.model.config.num_negatives,
            mask_time_indices=mask_time_indices,
        )

        batch["mask_time_indices"] = torch.tensor(mask_time_indices, dtype=torch.long, device=device)
        batch["sampled_negative_indices"] = torch.tensor(sampled_negative_indices, dtype=torch.long, device=device)

        return batch


def multiply_grads(params, c):
    """Scale, in place, the gradient of every parameter in *params* by *c*.

    Parameters that have no gradient are skipped. When *c* is a tensor it is
    moved to each gradient's device before the multiplication.
    """
    factor = c
    for param in params:
        grad = param.grad
        if grad is None:
            continue
        if torch.is_tensor(factor):
            factor = factor.to(grad.device)
        grad.data.mul_(factor)


def get_grad_norm(params, scale=1):
    """Return the global L2 norm of the gradients in *params* after dividing them by *scale*.

    Parameters without a gradient are ignored; an empty iterable yields 0.0.
    """
    squared_norms = [
        (param.grad.detach().data / scale).norm(2).item() ** 2
        for param in params
        if param.grad is not None
    ]
    # accumulate in parameter order, starting from a float zero
    return sum(squared_norms, 0.0) ** 0.5

In [3]:
# Accelerator created with default settings; the cells below use it to prepare the
# model/optimizer/dataloaders, run the backward pass and save checkpoints.
accelerator = Accelerator()

In [4]:
# Fix the random seed via transformers' `set_seed` so repeated runs start from the same state.
set_seed(0)

In [5]:
from dataset import AudioDataset
import random

# Build a reproducible 80/10/10 train/validation/test split over every file found
# (recursively) below `parent_dir`.
parent_dir = 'data/mp3_train_files'

# os.walk yields directory entries in arbitrary, filesystem-dependent order, so the
# paths are sorted first; without this the fixed shuffle seed below would still
# produce a different split on another machine or after files are re-copied.
file_list = sorted(
    os.path.join(root, file)
    for root, _, files in os.walk(parent_dir)
    for file in files
)

# Fail early with a clear message instead of building empty datasets that only
# break later (e.g. when the number of training steps is computed).
if not file_list:
    raise FileNotFoundError(f"No files found under {parent_dir!r}")

# NOTE: this re-seeds Python's global `random` module (separately from the
# earlier `set_seed(0)` call) so the shuffle is the same on every run.
random.seed(42)
random.shuffle(file_list)

train_size = int(0.8 * len(file_list))
val_size = int(0.1 * len(file_list))
# the test split takes whatever remains after the rounded-down train/val sizes
test_size = len(file_list) - train_size - val_size

train_files = file_list[:train_size]
val_files = file_list[train_size:train_size + val_size]
test_files = file_list[train_size + val_size:]

train_dataset = AudioDataset(train_files)
val_dataset = AudioDataset(val_files)
test_dataset = AudioDataset(test_files)

In [6]:
# Default-constructed configuration and feature extractor; both classes come from
# the local `model` module rather than from `transformers`.
config = Wav2Vec2Config()
# NOTE(review): the data collator calls `.pad(...)` on this object — confirm that the
# local Wav2Vec2FeatureEncoder provides that method.
feature_extractor = Wav2Vec2FeatureEncoder()

In [7]:
# Instantiate the pre-training model directly from the configuration object.
model = Wav2Vec2ForPreTraining(config)

# Masking hyper-parameters are read from the config and handed to the data collator below.
mask_time_prob = config.mask_time_prob
mask_time_length = config.mask_time_length 

In [8]:
# Collator: pads every batch and samples mask / negative indices on the fly.
# It receives the model before `accelerator.prepare` is applied to it.
data_collator = DataCollatorForWav2Vec2Pretraining(
    model=model,
    feature_extractor=feature_extractor,
    mask_time_prob=mask_time_prob,
    mask_time_length=mask_time_length,
)

# Training batches are reshuffled every epoch; validation batches keep dataset order.
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=data_collator,
    batch_size=8,
)
eval_dataloader = DataLoader(
    val_dataset,
    collate_fn=data_collator,
    batch_size=8,
)

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-5,
    betas=[0.9, 0.999],
    eps=1e-8,
)

# Let Accelerate wrap everything for the current device / process setup.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model,
    optimizer,
    train_dataloader,
    eval_dataloader,
)

In [9]:
# Optimizer updates per epoch. The literal `/ 1` is a gradient-accumulation factor of 1
# (presumably mirroring `gradient_accumulation_steps`, which is only defined in a later cell).
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / 1)
# Total step budget: 3 epochs' worth of updates.
max_train_steps = 3 * num_update_steps_per_epoch

# 'linear' learning-rate schedule with zero warmup steps, spanning the whole step budget.
lr_scheduler = get_scheduler(
        name='linear',
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=max_train_steps,
    )

# Epoch count recovered from the step budget (3 with the values above, provided the
# training dataloader is non-empty; an empty one would divide by zero here).
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

In [10]:
# Effective batch size: per-process batch (8) x number of processes x accumulation factor (1).
# It is not read by any of the cells visible in this notebook.
total_batch_size = 8 * accelerator.num_processes * 1
# Gumbel temperature schedule used in the training loop: starts at the maximum, is
# multiplied by the decay factor once per completed step, and never drops below the minimum.
max_gumbel_temperature = 2.0
min_gumbel_temperature = 0.5
gumbel_temperature_decay = 0.999995
# Log training metrics every `logging_steps` batches.
logging_steps = 10
# Used for the logged loss and the checkpoint interval; the update condition in the
# training loop uses a literal 1 instead of this variable.
gradient_accumulation_steps = 1
# Save a checkpoint every `saving_steps` batches (times the accumulation factor).
saving_steps = 500
push_to_hub = False
# Directory that `save_pretrained` writes checkpoints to.
output_dir = 'weights'


In [11]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard event files for this run are written under this directory.
log_dir = "runs/1"  # Change this to your desired log directory
# NOTE(review): this is not guarded by a main-process check, so every process creates a writer.
writer = SummaryWriter(log_dir=log_dir)

In [12]:
# Pre-training loop: one optimizer update per batch, periodic logging to the
# progress bar and TensorBoard, periodic checkpoints, and a validation pass plus
# checkpoint at the end of every epoch.
progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
starting_epoch = 0

for epoch in range(starting_epoch, num_train_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        # number of masked frames in this batch; gradients and logged losses are
        # normalized by it
        num_losses = batch["mask_time_indices"].sum()

        # only needed for the masking statistic, so it is removed from the batch
        # before the batch is forwarded to the model
        sub_attention_mask = batch.pop("sub_attention_mask", None)
        if sub_attention_mask is None:
            sub_attention_mask = torch.ones_like(batch["mask_time_indices"])
        percent_masked = num_losses / sub_attention_mask.sum()

        # A batch without a single masked frame would make `1 / num_losses`
        # infinite and turn every gradient into inf/NaN, so never divide by
        # less than one.
        loss_normalizer = num_losses.clamp(min=1)

        outputs = model(**batch)

        loss = outputs.loss
        accelerator.backward(loss)

        # NOTE(review): `num_losses` is local to this process; when running with
        # more than one process it is not summed across devices — confirm this
        # is intended before training multi-GPU.
        multiply_grads(model.parameters(), 1 / loss_normalizer)

        # The accumulation factor is the literal 1 here, so this branch runs on
        # every batch.
        if (step + 1) % 1 == 0 or step == len(train_dataloader) - 1:
            # compute grad norm for monitoring (undo the AMP loss scale if one is active)
            scale = (
                accelerator.scaler._scale.item()
                if hasattr(accelerator, "scaler") and accelerator.scaler is not None
                else 1
            )

            grad_norm = get_grad_norm(model.parameters(), scale)

            optimizer.step()
            optimizer.zero_grad()

            # only advance the schedule when the optimizer actually stepped
            if not accelerator.optimizer_step_was_skipped:
                lr_scheduler.step()
            elif accelerator.is_local_main_process:
                progress_bar.write(
                    f"Gradients have overflown - skipping update step... Updating gradient scale to {scale}..."
                )

            # update gumbel temperature (exponential decay, floored at the minimum)
            gumbel_temperature = max(
                max_gumbel_temperature * gumbel_temperature_decay**completed_steps,
                min_gumbel_temperature,
            )
            # the prepared model may be wrapped; reach through `.module` if so
            if hasattr(model, "module"):
                model.module.set_gumbel_temperature(gumbel_temperature)
            else:
                model.set_gumbel_temperature(gumbel_temperature)

            progress_bar.update(1)
            completed_steps += 1

        if (step + 1) % (1 * logging_steps) == 0:
            # Detach the values that are logged. (`tensor.detach()` returns a new
            # tensor, so calling it without using the result has no effect.)
            # The "constrast_loss" key is kept as-is so existing TensorBoard tags
            # stay comparable across runs.
            train_logs = {
                "loss": (loss.detach() * gradient_accumulation_steps) / loss_normalizer,
                "constrast_loss": outputs.contrastive_loss.detach() / loss_normalizer,
                "div_loss": outputs.diversity_loss.detach() / loss_normalizer,
                "%_mask_idx": percent_masked / accelerator.num_processes,
                "ppl": outputs.codevector_perplexity,
                "lr": torch.tensor(optimizer.param_groups[0]["lr"]),
                "temp": torch.tensor(gumbel_temperature),
                "grad_norm": torch.tensor(grad_norm),
            }
            # convert to Python floats once and reuse them for both sinks
            train_scalars = {k: v.item() for k, v in train_logs.items()}
            log_str = "".join("| {}: {:.3e}".format(k, v) for k, v in train_scalars.items())

            if accelerator.is_local_main_process:
                progress_bar.write(log_str)
                for k, v in train_scalars.items():
                    writer.add_scalar(f'train/{k}', v, completed_steps)

        # periodic checkpoint
        if (step + 1) % (gradient_accumulation_steps * saving_steps) == 0:
            if (push_to_hub and epoch < num_train_epochs - 1) or output_dir is not None:
                accelerator.wait_for_everyone()
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.save_pretrained(
                    output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
                )

        # stop once the step budget is used up
        if completed_steps >= max_train_steps:
            break

    # 7. Validate!

    model.eval()

    # init logs (running sums over the validation batches)
    val_logs = {
        "val_loss": 0,
        "val_contrastive_loss": 0,
        "val_diversity_loss": 0,
        "val_num_losses": 0,
    }
    num_val_batches = 0
    for batch in eval_dataloader:
        with torch.no_grad():
            batch.pop("sub_attention_mask", None)
            outputs = model(**batch)

        val_logs["val_loss"] += outputs.loss
        val_logs["val_contrastive_loss"] += outputs.contrastive_loss
        val_logs["val_diversity_loss"] += outputs.diversity_loss
        val_logs["val_num_losses"] += batch["mask_time_indices"].sum()
        num_val_batches += 1

    if num_val_batches == 0:
        # With no validation batch the sums are still plain Python ints: dividing
        # them would raise ZeroDivisionError and ints have no `.item()`.
        if accelerator.is_local_main_process:
            progress_bar.write("| validation skipped: eval_dataloader produced no batches")
    else:
        # sum over devices in multi-processing
        if accelerator.num_processes > 1:
            val_logs = {k: accelerator.gather_for_metrics(v).sum() for k, v in val_logs.items()}

        # normalize by the number of masked frames (at least one, to avoid NaN);
        # note that "val_num_losses" itself is divided as well
        val_num_losses = val_logs["val_num_losses"].clamp(min=1)
        val_logs = {k: v / val_num_losses for k, v in val_logs.items()}

        val_scalars = {k: v.item() for k, v in val_logs.items()}
        log_str = "".join("| {}: {:.3e}".format(k, v) for k, v in val_scalars.items())

        if accelerator.is_local_main_process:
            progress_bar.write(log_str)
            # Log validation metrics to TensorBoard
            for k, v in val_scalars.items():
                writer.add_scalar(f'val/{k}', v, epoch)

    # end-of-epoch checkpoint
    if output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
        )


writer.close()


  0%|          | 0/4650 [00:00<?, ?it/s]

| loss: 4.723e+00| constrast_loss: 4.635e+00| div_loss: 8.764e-01| %_mask_idx: 4.016e-02| ppl: 7.908e+01| lr: 4.989e-05| temp: 2.000e+00| grad_norm: 3.050e+00
| loss: 4.707e+00| constrast_loss: 4.619e+00| div_loss: 8.826e-01| %_mask_idx: 4.016e-02| ppl: 7.513e+01| lr: 4.978e-05| temp: 2.000e+00| grad_norm: 2.339e+00
| loss: 4.720e+00| constrast_loss: 4.630e+00| div_loss: 8.972e-01| %_mask_idx: 4.016e-02| ppl: 6.580e+01| lr: 4.968e-05| temp: 2.000e+00| grad_norm: 2.109e+00
| loss: 4.710e+00| constrast_loss: 4.621e+00| div_loss: 8.884e-01| %_mask_idx: 4.016e-02| ppl: 7.141e+01| lr: 4.957e-05| temp: 2.000e+00| grad_norm: 1.851e+00
| loss: 4.708e+00| constrast_loss: 4.623e+00| div_loss: 8.594e-01| %_mask_idx: 4.016e-02| ppl: 8.997e+01| lr: 4.946e-05| temp: 2.000e+00| grad_norm: 1.409e+00
| loss: 4.704e+00| constrast_loss: 4.618e+00| div_loss: 8.528e-01| %_mask_idx: 4.016e-02| ppl: 9.424e+01| lr: 4.935e-05| temp: 1.999e+00| grad_norm: 1.576e+00
| loss: 4.704e+00| constrast_loss: 4.619e+00| 

In [13]:
# Close the TensorBoard writer. The training cell above already ends with the same
# call, so this matters mainly when that cell was interrupted before reaching it.
writer.close()