In [1]:
import os
os.environ['HF_HOME'] = 'huggingface'
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = 'True'
import math
from datasets import  load_dataset
from transformers import Wav2Vec2Processor, Wav2Vec2ConformerForCTC, TrainingArguments, Trainer
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union
import numpy as np
import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hub id of the pretrained checkpoint; the processor and the model below are both loaded from it.
model_name= 'facebook/wav2vec2-conformer-rel-pos-large-960h-ft'
# Path of a local Trainer checkpoint directory.
# NOTE(review): no other cell visible in this notebook reads `checkpoint_name` — confirm it is still needed.
checkpoint_name= 'checkpoints/checkpoint-750/'

In [3]:
# Processor for `model_name`; later cells use its `tokenizer` and `feature_extractor` attributes.
processor = Wav2Vec2Processor.from_pretrained(model_name)

In [4]:
# Build the training data from the local `train` directory using the `audiofolder` builder.
ds = load_dataset('audiofolder', data_dir='train', split='train')  # specify split to return a Dataset object instead of a DatasetDict

Resolving data files: 100%|██████████| 3750/3750 [00:00<00:00, 103070.36it/s]


Downloading and preparing dataset audiofolder/default to /home/alc/TIL-2023/ASR/huggingface/datasets/audiofolder/default-8b45e18fa8565078/0.0.0/6cbdd16f8688354c63b4e2a36e1585d05de285023ee6443ffd71c4182055c0fc...


Downloading data files: 100%|██████████| 3751/3751 [00:00<00:00, 209500.17it/s]
Downloading data files: 0it [00:00, ?it/s]
Extracting data files: 0it [00:00, ?it/s]
                                                        

Dataset audiofolder downloaded and prepared to /home/alc/TIL-2023/ASR/huggingface/datasets/audiofolder/default-8b45e18fa8565078/0.0.0/6cbdd16f8688354c63b4e2a36e1585d05de285023ee6443ffd71c4182055c0fc. Subsequent calls will reuse this data.


In [5]:
# Hold out 20% of the examples for validation.
# A fixed seed keeps the split identical across kernel restarts, so validation WER stays
# comparable between runs and resumed training never sees former validation examples.
ds = ds.train_test_split(test_size=0.2, seed=42)

In [6]:
# Peek at one decoded example. Index the row first so only this one audio file is decoded;
# `ds['train']['audio'][0]` would decode the entire column before picking the first entry.
ds['train'][0]['audio']

{'path': '/home/alc/TIL-2023/ASR/train/audio/train_02020.wav',
 'array': array([2.13623047e-04, 6.10351562e-05, 0.00000000e+00, ...,
        1.22070312e-04, 1.83105469e-04, 9.15527344e-05]),
 'sampling_rate': 16000}

In [7]:
def prepare_dataset(batch):
    """Turn a batch of raw examples into model inputs and CTC label ids.

    Args:
        batch: dict of columns; reads ``batch["audio"]`` (items with an ``"array"`` entry)
            and ``batch["annotation"]`` (the transcripts).

    Returns:
        The same dict with these columns added:
        ``input_values`` – the processor output for each clip, still wrapped in its outer
        list (the data collator unwraps it with ``[0]``);
        ``input_length`` / ``length`` – number of values in each clip;
        ``labels`` – token ids of the transcripts.
    """
    # The processor is imported and built inside the function, presumably so every
    # `num_proc` worker process creates its own instance — TODO confirm.
    model_name = 'facebook/wav2vec2-conformer-rel-pos-large-960h-ft'
    from transformers import Wav2Vec2Processor
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    batch["input_values"] = [processor(audio["array"], sampling_rate=16000).input_values for audio in batch["audio"]]
    # Each entry is a one-element wrapper around the actual signal, so the length has to be
    # taken from the inner sequence; `len(entry)` would be 1 for every clip and make
    # length-based batching meaningless.
    batch["input_length"] = [len(values[0]) for values in batch["input_values"]]
    batch['length'] = batch["input_length"]
    batch["labels"] = processor(text=batch["annotation"]).input_ids
    return batch


# Run `prepare_dataset` over the data in batches of 256 examples using 8 worker processes.
ds = ds.map(prepare_dataset, num_proc=8, batched=True, batch_size=256)

                                                                             

In [8]:
@dataclass
class DataCollatorCTCWithPadding:
    """Collate prepared examples into one padded batch for CTC training.

    Audio inputs and label ids are padded separately because their lengths differ.
    """
    # Processor whose `pad` method is used for both the audio inputs and the labels.
    processor: Wav2Vec2Processor
    # Padding strategy forwarded unchanged to `processor.pad`.
    padding: Union[bool, str] = "longest"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad `features` and return the batch with a `labels` entry added.

        Each feature must provide `input_values` (whose first element is the signal)
        and `labels` (token ids).
        """
        audio_items = []
        label_items = []
        for example in features:
            # The stored input is wrapped in an outer list; take the signal itself.
            audio_items.append({"input_values": example["input_values"][0]})
            label_items.append({"input_ids": example["labels"]})

        batch = self.processor.pad(audio_items, padding=self.padding, return_tensors="pt")
        padded_labels = self.processor.pad(labels=label_items, padding=self.padding, return_tensors="pt")

        # Padded label positions become -100 so the loss ignores them.
        is_padding = padded_labels.attention_mask.ne(1)
        batch["labels"] = padded_labels["input_ids"].masked_fill(is_padding, -100)
        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding="longest")

In [9]:
# Word-error-rate metric; `compute_metrics` calls its `compute` method.
wer = evaluate.load("wer")
def compute_metrics(pred):
    """Compute the word error rate for one evaluation pass.

    Args:
        pred: object exposing `predictions` (logits, class scores on the last axis)
            and `label_ids` (reference token ids, with -100 at ignored positions).

    Returns:
        ``{"wer": value}`` as produced by the module-level `wer` metric.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Swap the -100 markers for the pad token on a copy so the caller's array is not modified.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    # Reference ids are decoded with group_tokens=False, unlike the predictions.
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    return {"wer": wer.compute(predictions=pred_str, references=label_str)}

In [13]:
# Load the pretrained CTC model with overridden loss-reduction, padding and masking settings.
# NOTE(review): the numbers in the trailing comments look like the values these settings
# would otherwise have — confirm against the model config.
model = Wav2Vec2ConformerForCTC.from_pretrained(
    model_name,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    mask_time_prob=0.5,  # 0.05
    mask_time_length=10, # 10
    mask_feature_prob=0.5, # 0
    mask_feature_length=10, # 10
)

In [10]:
# Optional step: running this cell freezes the model's feature encoder.
# Skip this cell if you want to fine-tune the whole model.
# Original author's note: "Right now this gives us better perf."
# NOTE(review): it is ambiguous whether "this" refers to freezing or to skipping the cell — confirm.
model.freeze_feature_encoder()

In [14]:
# Batch sizing: each device processes `per_gpu_bs` examples per step, and gradients are
# accumulated over ceil(effective_bs / per_gpu_bs) steps (8 with the values below).
per_gpu_bs = 4
effective_bs = 32
# Training configuration. Checkpoints are written to "checkpoints", and the best checkpoint
# is selected by the lowest "wer" (greater_is_better=False).
# NOTE(review): `save_steps` / `eval_steps` are step-based settings while both strategies are
# 'epoch'; presumably they have no effect here — confirm.
training_args = TrainingArguments(
    output_dir="checkpoints",
    overwrite_output_dir =True,
    per_device_train_batch_size=per_gpu_bs,
    gradient_accumulation_steps=math.ceil(effective_bs/per_gpu_bs),
    learning_rate=1e-4,
    num_train_epochs=20,
    gradient_checkpointing=False,
    fp16=True,
    # bf16=True,  # for A100
    fp16_full_eval=True,
    # bf16_full_eval=True,  # for A100
    group_by_length=True,  # slows down
    evaluation_strategy="epoch",
    save_strategy='epoch',  # epoch
    save_safetensors=True,
    per_device_eval_batch_size=4,
    save_steps=1,
    eval_steps=1,
    logging_steps=100,
    save_total_limit=3,
    lr_scheduler_type='cosine',
    load_best_model_at_end=True,  # True
    adam_beta1=0.9,
    adam_beta2=0.98,  # follow fairseq fintuning config
    warmup_ratio=0.22, # follow Ranger21
    weight_decay=1e-4,  # follow Ranger21
    metric_for_best_model="wer",
    greater_is_better=False,
    report_to=['tensorboard'],
    # 24 data-loading workers everywhere except Windows (os.name == 'nt'), which gets 1.
    dataloader_num_workers=24 if os.name != 'nt' else 1)

In [15]:
class CTCTrainer(Trainer):
    """Trainer variant whose training step scales the loss and runs backward itself."""

    def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Run one forward/backward pass on a batch and return the detached loss.

        Args:
            model (:obj:`nn.Module`):
                The model to train; it is switched to train mode first.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The batch, passed through ``self._prepare_inputs`` and then to
                ``self.compute_loss``.

        Return:
            :obj:`torch.Tensor`: The loss for this batch, already divided by the number of
            gradient-accumulation steps when that number is greater than 1.
        """
        model.train()
        prepared_inputs = self._prepare_inputs(inputs)
        loss = self.compute_loss(model, prepared_inputs)

        accumulation_steps = self.args.gradient_accumulation_steps
        if accumulation_steps > 1:
            loss = loss / accumulation_steps

        scaled_loss = self.scaler.scale(loss)
        if os.name == 'nt':
            scaled_loss.backward()
        else:
            # Uses the module-level `accelerator`, which is created in a separate cell.
            # NOTE(review): the loss is already scaled by `self.scaler`; if `accelerator`
            # applies its own fp16 scaling too, it is scaled twice — confirm.
            accelerator.backward(scaled_loss)
        return loss.detach()

In [16]:
# Create the module-level `accelerator` used by `CTCTrainer.training_step`.
# It is skipped on Windows (os.name == 'nt'), where the trainer calls backward directly.
if os.name != 'nt':
    from accelerate import Accelerator
    accelerator = Accelerator(mixed_precision='fp16', dynamo_backend='eager')  # FP8 needs transformer_engine package which is only on Linux with Hopper GPUs

In [17]:
def tri_stage_schedule(epoch: int, max_epoch = training_args.num_train_epochs, stage_ratio = [0.1, 0.4, 0.5], peak_lr = training_args.learning_rate, initial_lr_scale=0.01, final_lr_scale=0.05):
    """Tri-stage learning rate (linear warm-up, hold, exponential decay) for `epoch`.

    Modelled on
    https://github.com/facebookresearch/fairseq/blob/5ecbbf58d6e80b917340bcbf9d7bdbb539f0f92b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py#L51

    Args:
        epoch: current position in the schedule, counted from 0.
        max_epoch: total schedule length.
        stage_ratio: fractions of `max_epoch` spent in warm-up, hold and decay; must sum to 1.
        peak_lr: learning rate during the hold stage.
        initial_lr_scale: warm-up starts at ``initial_lr_scale * peak_lr``.
        final_lr_scale: decay ends at ``final_lr_scale * peak_lr``.

    Returns:
        The absolute learning rate as a float (not a multiplier on a base learning rate).

    Raises:
        AssertionError: if `stage_ratio` does not sum to 1.
    """
    # Tolerant comparison: ratios such as [0.1, 0.2, 0.7] do not sum to exactly 1.0 in floats.
    assert math.isclose(sum(stage_ratio), 1)

    warmup_end = stage_ratio[0] * max_epoch
    # Stage boundaries are cumulative: the hold stage ends after warm-up + hold.
    hold_end = (stage_ratio[0] + stage_ratio[1]) * max_epoch
    decay_len = stage_ratio[2] * max_epoch
    initial_lr = initial_lr_scale * peak_lr
    final_lr = final_lr_scale * peak_lr

    if epoch < warmup_end:  # linear warmup
        return initial_lr + (peak_lr - initial_lr) * epoch / warmup_end
    if epoch < hold_end:  # constant
        return peak_lr
    if epoch < max_epoch and decay_len > 0:  # exponential decay
        decay_factor = -math.log(final_lr_scale) / decay_len
        # Decay by the distance travelled into the decay stage, reaching final_lr at max_epoch.
        return peak_lr * math.exp(-decay_factor * (epoch - hold_end))
    # Past the end of the schedule: stay at the final learning rate.
    return final_lr

In [18]:
# Earlier experiments with a custom optimizer / scheduler, kept for reference. They are
# disabled together with the `optimizers=` argument below.
# max_steps = math.ceil(training_args.num_train_epochs * len(ds['train']) / training_args.gradient_accumulation_steps / min(training_args.per_device_train_batch_size, len(ds['train'])))
# optimizer = Ranger21(model.parameters(), num_iterations=max_steps, lr=1e-4)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-8, foreach=False)  # https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/config/finetuning/base_960h.yaml
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps)
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=tri_stage_schedule)  # following FAIR finetuning settings
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: x)  # constant LR, stays same throughout, for Ranger21

# Build the trainer on the train/test splits. The feature extractor is passed in the
# `tokenizer` slot.
trainer = CTCTrainer( # TODO: fix the CTCTrainer occupying 24 threads so it doesnt freeze the val
    model=model,
    args=training_args,
    train_dataset=ds['train'],
    eval_dataset=ds['test'],
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # optimizers=(optimizer, scheduler),
)
# Outside Windows, pass the trainer's wrapped model, optimizer and scheduler through
# `accelerator.prepare` before training.
# NOTE(review): no optimizer or scheduler is passed to the trainer above, so
# `trainer.optimizer` / `trainer.lr_scheduler` may still be unset at this point — confirm
# what `prepare` receives.
if os.name != 'nt':  # windows does not support torch.compile yet
    # pass
    trainer.model_wrapped, trainer.optimizer, trainer.lr_scheduler = accelerator.prepare(trainer.model_wrapped, trainer.optimizer, trainer.lr_scheduler)
trainer.train()
if os.name != 'nt':
    accelerator.wait_for_everyone()

   function: 'forward' (/home/alc/venvs/til2023/lib/python3.10/site-packages/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py:806)
   reasons:  ___check_obj_id(self, 140261974126752)
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
   function: 'forward' (/home/alc/venvs/til2023/lib/python3.10/site-packages/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py:563)
   reasons:  ___check_obj_id(self, 140261973855248)
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
   function: 'forward' (/home/alc/venvs/til2023/lib/python3.10/site-packages/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py:660)
   reasons:  ___check_obj_id(self, 140261973854672)
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
   function: 'forward' (/home/alc/venvs/til2023/lib/python3.10/site-packages/transformers/m

Epoch,Training Loss,Validation Loss,Wer
0,No log,0.147936,0.070841
1,1.317300,0.074253,0.038321


   function: '_apply_relative_embeddings' (/home/alc/venvs/til2023/lib/python3.10/site-packages/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py:741)
   reasons:  ___check_obj_id(self, 140261974006640)
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
   function: '__init__' (<string>:2)
   reasons:  tensor 'logits' strides mismatch at index 0. expected 9792, actual 9472
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
   function: 'forward' (/home/alc/venvs/til2023/lib/python3.10/site-packages/transformers/activations.py:149)
   reasons:  tensor 'input' strides mismatch at index 0. expected 239616, actual 831488
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
   function: 'forward' (/home/alc/venvs/til2023/lib/python3.10/site-packages/accelerate/utils/operations.py:520)
   reasons:  tensor 'kwargs['labels']' strides 

KeyboardInterrupt: 

In [None]:
# Undo the accelerator wrapping (non-Windows only), then write the model and the tokenizer
# files to the 'wav2vec2-conformer' directory.
# NOTE(review): later cells load a full Wav2Vec2Processor from this directory; only the
# tokenizer is saved explicitly here, so confirm the feature-extractor config ends up there too.
if os.name != 'nt':
    trainer.model_wrapped = accelerator.unwrap_model(trainer.model_wrapped)
trainer.save_model('wav2vec2-conformer')
processor.tokenizer.save_pretrained('wav2vec2-conformer')

# INFERENCE

In [2]:
# Reload the processor saved in 'wav2vec2-conformer' and copy its tokenizer files into a
# specific checkpoint directory.
from transformers import Wav2Vec2Processor
processor = Wav2Vec2Processor.from_pretrained('wav2vec2-conformer')
processor.tokenizer.save_pretrained('checkpoints/checkpoint-4125/')

('checkpoints/checkpoint-4125/tokenizer_config.json',
 'checkpoints/checkpoint-4125/special_tokens_map.json',
 'checkpoints/checkpoint-4125/vocab.json',
 'checkpoints/checkpoint-4125/added_tokens.json')

In [1]:
# Infer
import os
os.environ['HF_HOME'] = 'huggingface'
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = 'True'
import torch
import datasets
from transformers import Wav2Vec2Processor, Wav2Vec2ConformerForCTC
from transformers.pipelines.pt_utils import KeyDataset
from tqdm.auto import tqdm
import pandas as pd
from torch.utils.data import DataLoader
# force=True keeps this cell re-runnable: without it, a second call in the same kernel
# raises RuntimeError because the start method has already been set.
torch.multiprocessing.set_start_method('spawn', force=True)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the local "test" data and expose only the raw waveform of each example
# (the "array" entry inside the "audio" column).
dataset = datasets.load_dataset("test", split="train")
dataset = KeyDataset(KeyDataset(dataset, "audio"), "array")
# Table that receives the predictions later.
# NOTE(review): predictions are assigned to it by position, which assumes its rows are in
# the same order as `dataset` — confirm.
test_ds = pd.read_csv('Test_Advanced.csv')

Resolving data files: 100%|██████████| 12000/12000 [00:00<00:00, 25783.84it/s]
Found cached dataset audiofolder (/home/alc/TIL-2023/ASR/huggingface/datasets/audiofolder/test-b719a705ed310f32/0.0.0/6cbdd16f8688354c63b4e2a36e1585d05de285023ee6443ffd71c4182055c0fc)


In [3]:
def clean(annotation):
    """Strip apostrophes, and the letter right after each one, from a transcript.

    The tokenizer can emit "'" but the target transcripts do not contain it, so "IT'S"
    becomes "IT". Every apostrophe in the string is handled, and text after a second
    apostrophe is kept. A space following an apostrophe is preserved so that neighbouring
    words are not merged ("BOYS' TOYS" -> "BOYS TOYS").

    Args:
        annotation: decoded transcript.

    Returns:
        The transcript without apostrophes and their trailing letters.
    """
    import re  # local import keeps this cell self-contained

    # An apostrophe plus, if present, one following character that is neither
    # whitespace nor another apostrophe.
    return re.sub(r"'[^\s']?", "", annotation)

def collate_fn(batch):
    """Run the module-level `processor` on a list of waveforms.

    The processor is called with a 16 kHz sampling rate, padding enabled and PyTorch
    tensors requested. Its whole output is returned under the "input_values" key.
    """
    return {"input_values": processor(batch, sampling_rate=16000, return_tensors="pt", padding=True)}

In [4]:
# Run the fine-tuned checkpoint over the whole test set and collect the logits per batch.
processor = Wav2Vec2Processor.from_pretrained("wav2vec2-conformer")
# NOTE(review): worker processes are enabled only on Windows here, the opposite of the
# training configuration — confirm this is intended.
data_loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn, pin_memory=True, num_workers=4 if os.name == 'nt' else 0)
checkpoint1 = 'wav2vec2-checkpoints/checkpoint-1593'
# checkpoint2 = 'checkpoints/checkpoint-11250'
model1 = Wav2Vec2ConformerForCTC.from_pretrained(checkpoint1).to('cuda')
# model2 = Wav2Vec2ConformerForCTC.from_pretrained(checkpoint2).to('cuda')
model1.eval()
# model2.eval()
logits1 = []
# logits2 = []
logits = []
with torch.no_grad():
    for batch in tqdm(data_loader):
        # `collate_fn` puts the full processor output under "input_values"; it is moved
        # to the GPU and unpacked into the model call.
        inputs = batch['input_values'].to('cuda')
        outputs1 = model1(**inputs).logits
        # outputs2 = model2(**inputs).logits
        # Store a CPU copy so the logits of the entire test set do not pile up in GPU memory.
        logits1.append(outputs1.cpu())
        # logits2.append(outputs2)

100%|██████████| 375/375 [02:53<00:00,  2.16it/s]


In [5]:
# logits = [(l1 + l2) / 2 for l1, l2 in zip(logits1, logits2)]
# Greedy decoding: take the arg-max token at every frame and let the processor turn each
# batch of ids into transcripts, flattened into one list.
results = [
    transcript
    for batch_logits in logits1
    for transcript in processor.batch_decode(torch.argmax(batch_logits, dim=-1))
]

In [None]:
# Attach the cleaned transcripts, reduce every path to its last '/'-separated component,
# and write the submission file.
test_ds['annotation'] = [clean(transcript) for transcript in results]
test_ds['path'] = test_ds['path'].map(lambda full_path: full_path.split('/')[-1])
test_ds.to_csv('Test_Advanced build in aug 0.7 raw 0.2 unfreeze no adam beta val_0.006305.csv', index=False)  # change file name