In [None]:
# no_trainer: https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py


In [1]:
import os, math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Union

import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import DatasetDict, concatenate_datasets
import datasets
from tqdm.auto import tqdm

from torch.utils.data.dataloader import DataLoader
import transformers
from transformers import (
    AdamW,
    SchedulerType,
    Wav2Vec2Config,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForPreTraining,
    get_scheduler,
    #is_wandb_available,
    set_seed,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices

2023-07-07 17:18:45.039064: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
logger = get_logger(__name__)
SAMPLINGRATE=16000
SEED=25
#WANDB_PROJECT='pretrain_wav2vec2_accelerator'
#MODEL_NAME='patrickvonplaten/wav2vec2-base-v2'
MODEL_NAME="patrickvonplaten/mms-300m"
#MODEL_NAME='facebook/mms-300m'
#MODEL_NAME="facebook/wav2vec2-xls-r-300m"
DATASET='/root/openstream/kashmiri/pretrain/pretrain_dataset'
#validation_split_percentage=10
max_duration_in_seconds=25
min_duration_in_seconds=2
audio_column_name='audio'
preprocessing_num_workers=1  #4
cachedir='/root/openstream/kashmiri/pretrain/cache_dir'
train_cache_file_name=f'{cachedir}/train_cache.cache'
validation_cache_file_name=f'{cachedir}/valid_cache.cache'
pad_to_multiple_of=None
learning_rate=.005  #5e-5
gradient_accumulation_steps=8
gradient_checkpointing=True
mask_time_prob=None
mask_time_length=None
pad_to_multiple_of=None
per_device_train_batch_size=8
per_device_eval_batch_size=8
max_train_steps=20000
num_warmup_steps=32000
lr_scheduler_type='linear' 
# ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] 

gumbel_temperature_decay=0.999995
min_gumbel_temperature=0.5
max_gumbel_temperature=2.0

num_train_epochs=1
logging_steps=500
saving_steps=10000


output_dir='/root/openstream/kashmiri/pretrain/models/mms_pretrained_kashmiri'
set_seed(SEED)

In [3]:
@dataclass
class DataCollatorForWav2Vec2Pretraining:
    model: Wav2Vec2ForPreTraining
    feature_extractor: Wav2Vec2FeatureExtractor
    padding: Union[bool, str] = "longest"
    pad_to_multiple_of: Optional[int] = None
    mask_time_prob: Optional[float] = 0.65
    mask_time_length: Optional[int] = 10
    
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # reformat list to dict and set to pytorch format
        inp_features = [{"input_values": feature["input_values"]} for feature in features]
        batch = self.feature_extractor.pad(
            inp_features,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        device = batch["input_values"].device
        batch_size = batch["input_values"].shape[0]

        mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
        # make sure masked sequence length is a Python scalar
        mask_indices_seq_length = int(mask_indices_seq_length)

        # make sure that no loss is computed on padded inputs
        if batch.get("attention_mask") is not None:
            # compute real output lengths according to convolution formula
            batch["sub_attention_mask"] = self.model._get_feature_vector_attention_mask(
                mask_indices_seq_length, batch["attention_mask"]
            )

        features_shape = (batch_size, mask_indices_seq_length)

        # sample randomly masked indices
        mask_time_indices = _compute_mask_indices(
            features_shape,
            self.mask_time_prob,
            self.mask_time_length,
            attention_mask=batch.get("sub_attention_mask"),
        )

        # sample negative indices
        sampled_negative_indices = _sample_negative_indices(
            features_shape,
            self.model.config.num_negatives,
            mask_time_indices=mask_time_indices,
        )
        batch["mask_time_indices"] = torch.tensor(mask_time_indices, dtype=torch.long, device=device)
        batch["sampled_negative_indices"] = torch.tensor(sampled_negative_indices, dtype=torch.long, device=device)

        return batch

In [4]:
def multiply_grads(params, c):
    """Multiplies grads by a constant *c*."""
    for p in params:
        if p.grad is not None:
            if torch.is_tensor(c):
                c = c.to(p.grad.device)
            p.grad.data.mul_(c)


def get_grad_norm(params, scale=1):
    """Compute grad norm given a gradient scale."""
    total_norm = 0.0
    for p in params:
        if p.grad is not None:
            param_norm = (p.grad.detach().data / scale).norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm**0.5
    return total_norm


In [5]:
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()

logger.info(accelerator.state, main_process_only=False)

datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()

# set up weights and biases if available
# if is_wandb_available():
#     import wandb
#     wandb.init(project=WANDB_PROJECT)


In [6]:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
feature_extractor.do_normalize = True

loading configuration file preprocessor_config.json from cache at /root/.cache/huggingface/hub/models--patrickvonplaten--mms-300m/snapshots/4ee317ce793c53dbc041fc4376c7558292dd38dc/preprocessor_config.json
Feature extractor Wav2Vec2FeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0,
  "return_attention_mask": true,
  "sampling_rate": 16000
}



In [7]:
dset=DatasetDict().load_from_disk(DATASET)

In [8]:
dset

DatasetDict({
    train: Dataset({
        features: ['audio', 'dur'],
        num_rows: 7575
    })
    val: Dataset({
        features: ['audio', 'dur'],
        num_rows: 841
    })
})

In [9]:
max_length = int(max_duration_in_seconds * feature_extractor.sampling_rate)
min_length = int(min_duration_in_seconds * feature_extractor.sampling_rate)

In [10]:
def prepare_dataset(batch):
    sample = batch[audio_column_name]

    inputs = feature_extractor(
        sample["array"], sampling_rate=sample["sampling_rate"], max_length=max_length, truncation=True
        #sample, sampling_rate=SAMPLINGRATE, max_length=max_length, truncation=True, padding=True
    )
    batch["input_values"] = inputs.input_values[0]
    batch["input_length"] = len(inputs.input_values[0])
    
    return batch

In [11]:
cache_file_names=None
if train_cache_file_name is not None:
    cache_file_names = {"train": train_cache_file_name, "val": validation_cache_file_name}
    
with accelerator.main_process_first():
    vectorized_datasets = dset.map(
        prepare_dataset,
        num_proc=preprocessing_num_workers,
        remove_columns=dset["train"].column_names,
        #cache_file_names=cache_file_names,
    )

#     if min_length > 0.0:
#         vectorized_datasets = vectorized_datasets.filter(
#             lambda x: x > min_length,
#             num_proc=preprocessing_num_workers,
#             input_columns=["input_length"],
#         )

    vectorized_datasets = vectorized_datasets.remove_columns("input_length")


Loading cached processed dataset at /root/openstream/kashmiri/pretrain/pretrain_dataset/train/cache-810b03b772a56c4c.arrow
Loading cached processed dataset at /root/openstream/kashmiri/pretrain/pretrain_dataset/val/cache-4405e6582e3a2a32.arrow


In [12]:
vectorized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_values'],
        num_rows: 7575
    })
    val: Dataset({
        features: ['input_values'],
        num_rows: 841
    })
})

In [13]:
#%debug

In [14]:
config = Wav2Vec2Config.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForPreTraining(config)


loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--patrickvonplaten--mms-300m/snapshots/4ee317ce793c53dbc041fc4376c7558292dd38dc/config.json
Model config Wav2Vec2Config {
  "activation_dropout": 0.0,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForPreTraining"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 768,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": true,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": true,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_dropout

In [15]:
if gradient_checkpointing:
    model.gradient_checkpointing_enable()

mask_time_prob = config.mask_time_prob if mask_time_prob is None else mask_time_prob
mask_time_length = config.mask_time_length if mask_time_length is None else mask_time_length

data_collator = DataCollatorForWav2Vec2Pretraining(
    model=model,
    feature_extractor=feature_extractor,
    
    pad_to_multiple_of=pad_to_multiple_of,
    mask_time_prob=mask_time_prob,
    mask_time_length=mask_time_length,
)
train_dataloader = DataLoader(
    vectorized_datasets["train"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=per_device_train_batch_size,
)
eval_dataloader = DataLoader(
    vectorized_datasets["val"], collate_fn=data_collator, batch_size=per_device_eval_batch_size
)

# Optimizer
optimizer = AdamW(
    list(model.parameters()),
    lr=learning_rate,
    betas=[0.9, 0.999],
    eps=1e-8,
)




In [16]:
#dd=next(iter(train_dataloader))

In [17]:
# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)

if max_train_steps is None:
    max_train_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    name=lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=max_train_steps,
)

# Afterwards we recalculate our number of training epochs
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)


In [18]:
len(train_dataloader)

947

In [19]:
max_train_steps,num_train_epochs,accelerator.num_processes

(20000, 169, 1)

In [20]:
# Train
total_batch_size = per_device_train_batch_size * accelerator.num_processes * gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(vectorized_datasets['train'])}")
logger.info(f"  Num Epochs = {num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {per_device_train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {max_train_steps}")
completed_steps = 0
starting_epoch = 0

progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)

for epoch in range(starting_epoch, num_train_epochs):
    model.train()
    
    for step, batch in enumerate(train_dataloader):
        # compute num of losses
        
        num_losses = batch["mask_time_indices"].sum()
        sub_attention_mask = batch.pop("sub_attention_mask", None)
        sub_attention_mask = (
            sub_attention_mask if sub_attention_mask is not None else torch.ones_like(batch["mask_time_indices"])
        )
        percent_masked = num_losses / sub_attention_mask.sum()
        
        # forward
        outputs = model(**batch)

        # divide loss by gradient accumulation steps since gradients
        # are accumulated for multiple backward passes in PyTorch
        loss = outputs.loss / gradient_accumulation_steps
        assert 1==2
        accelerator.backward(loss)

        # make sure that `num_losses` is summed for distributed training
        # and average gradients over losses of all devices
        if accelerator.state.num_processes > 1:
            num_losses = accelerator.gather_for_metrics(num_losses).sum()
            gradient_multiplier = accelerator.state.num_processes / num_losses
            multiply_grads(model.module.parameters(), gradient_multiplier)
        else:
            multiply_grads(model.parameters(), 1 / num_losses)

        # update step
        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
            # compute grad norm for monitoring
            scale = (
                accelerator.scaler._scale.item()
                if hasattr(accelerator, "scaler") and accelerator.scaler is not None
                else 1
            )
            if accelerator.state.num_processes > 1:
                grad_norm = get_grad_norm(model.module.parameters(), scale)
            else:
                grad_norm = get_grad_norm(model.parameters(), scale)

            # update parameters
            optimizer.step()
            optimizer.zero_grad()

            if not accelerator.optimizer_step_was_skipped:
                lr_scheduler.step()
            elif accelerator.is_local_main_process:
                progress_bar.write(
                    f"Gradients have overflown - skipping update step... Updating gradient scale to {scale}..."
                )

            # update gumbel temperature
            gumbel_temperature = max(
                max_gumbel_temperature * gumbel_temperature_decay**completed_steps,
                min_gumbel_temperature,
            )
            if hasattr(model, "module"):
                model.module.set_gumbel_temperature(gumbel_temperature)
            else:
                model.set_gumbel_temperature(gumbel_temperature)

            progress_bar.update(1)
            completed_steps += 1
        # logs
        if (step + 1) % (gradient_accumulation_steps * logging_steps) == 0:
            loss.detach()
            outputs.contrastive_loss.detach()
            outputs.diversity_loss.detach()

            if accelerator.state.num_processes > 1:
                loss = accelerator.gather_for_metrics(loss).sum()
                outputs.contrastive_loss = accelerator.gather_for_metrics(outputs.contrastive_loss).sum()
                outputs.diversity_loss = accelerator.gather_for_metrics(outputs.diversity_loss).sum()
                percent_masked = accelerator.gather_for_metrics(percent_masked).sum()

            train_logs = {
                "loss": (loss * gradient_accumulation_steps) / num_losses,
                "constrast_loss": outputs.contrastive_loss / num_losses,
                "div_loss": outputs.diversity_loss / num_losses,
                "%_mask_idx": percent_masked / accelerator.num_processes,
                "ppl": outputs.codevector_perplexity,
                "lr": torch.tensor(optimizer.param_groups[0]["lr"]),
                "temp": torch.tensor(gumbel_temperature),
                "grad_norm": torch.tensor(grad_norm),
            }
            log_str = ""
            for k, v in train_logs.items():
                log_str += "| {}: {:.3e}".format(k, v.item())

            if accelerator.is_local_main_process:
                progress_bar.write(log_str)
#                 if is_wandb_available():
#                     wandb.log(train_logs)
                    
        # save model every `args.saving_steps` steps
        if (step + 1) % (gradient_accumulation_steps * saving_steps) == 0:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
            )
        
        # if completed steps > `args.max_train_steps` stop
        if completed_steps >= max_train_steps:
            break
            
    #validate
    model.eval()
    val_logs = {
        "val_loss": 0,
        "val_contrastive_loss": 0,
        "val_diversity_loss": 0,
        "val_num_losses": 0,
        }
    
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            batch.pop("sub_attention_mask", None)
            outputs = model(**batch)

        val_logs["val_loss"] += outputs.loss
        val_logs["val_contrastive_loss"] += outputs.contrastive_loss
        val_logs["val_diversity_loss"] += outputs.diversity_loss
        val_logs["val_num_losses"] += batch["mask_time_indices"].sum()

    # sum over devices in multi-processing
    if accelerator.num_processes > 1:
        val_logs = {k: accelerator.gather_for_metrics(v).sum() for k, v in val_logs.items()}

    val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()}

    log_str = ""
    for k, v in val_logs.items():
        log_str += "| {}: {:.3e}".format(k, v.item())

    if accelerator.is_local_main_process:
        progress_bar.write(log_str)
#         if is_wandb_available():
#             wandb.log(val_logs)

    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
    )

  0%|          | 0/20000 [00:00<?, ?it/s]



AssertionError: 

In [None]:
loss

In [21]:
accelerator.backward(loss)

RuntimeError: handle_0 INTERNAL ASSERT FAILED at "../c10/cuda/driver_api.cpp":15, please report a bug to PyTorch. 