In [11]:
%load_ext autoreload
%autoreload 2
    
import sys
sys.path.append('/Users/xinzheng/workspace/chenlong/services/asr_small')
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import Audio

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import numpy as np

import torch
import torch.nn.functional as func
from torch.utils.data import DataLoader
import whisper
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperModel, WhisperTokenizerFast
from datasets import load_dataset, load_from_disk
import loralib as lora

from smallwhisper import SmallWhisper, SmallWhisperConfig


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Single source of truth for the pretrained checkpoint used by every loader below,
# so processor, reference model and tokenizer cannot drift apart.
MODEL_CHECKPOINT = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(MODEL_CHECKPOINT)
hf_model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
tokenizer = WhisperTokenizerFast.from_pretrained(MODEL_CHECKPOINT)

In [3]:
# Special decoder token ids the data collator uses to separate prompt from transcription.
decoder_start_token_id = hf_model.config.decoder_start_token_id  # <|startoftranscript|>
# Look <|startofprev|> up by its token string: indexing `all_special_ids` by position
# silently depends on the order and number of special tokens in the installed tokenizer.
decoder_prev_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>")
# An unknown token string maps to the unk id instead of raising, so check explicitly.
assert decoder_prev_token_id != tokenizer.unk_token_id, "<|startofprev|> not found in tokenizer vocabulary"


In [4]:
# Location of the on-disk `datasets` dataset used for this experiment.
ONE_BATCH_PATH = '/Users/xinzheng/workspace/chenlong/services/asr_small/one_batch.hf/'
small_ds = load_from_disk(ONE_BATCH_PATH)

In [5]:
# Show the dataset summary (feature columns and row count) as the cell output.
small_ds

Dataset({
    features: ['audio', 'text', 'input_features', 'input_length', 'labels'],
    num_rows: 128
})

In [6]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor ([`WhisperProcessor`])
            The processor used for processing the data. Its ``feature_extractor.pad`` pads the
            input features and its ``tokenizer.pad`` pads the label token ids.
        decoder_start_token_id (:obj: `int`)
            The start-of-sequence token id of the decoder (e.g. ``<|startoftranscript|>``).
            Label positions up to and including its first occurrence are treated as prompt
            and excluded from the loss.
        decoder_prev_token_id (:obj: `int`)
            The start-of-prompt token id of the decoder (e.g. ``<|startofprev|>``).
            Stored on the collator; not read by ``__call__``.
        input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`"max_length"`):
            Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`"max_length"`):
            Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
            See above for details.
        max_target_length (:obj:`int`, `optional`):
            Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
    """

    processor: Any
    decoder_start_token_id: int
    decoder_prev_token_id: int
    input_padding: Union[bool, str] = "max_length"
    target_padding: Union[bool, str] = "max_length"
    max_target_length: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into a padded batch of PyTorch tensors.

        Each feature must provide ``"input_features"`` and ``"labels"``. ``labels`` is the full
        token sequence, starting with the decoder's first input token.

        Returns:
            The padded ``input_features`` batch plus:
            * ``decoder_input_ids``: the padded label ids without their last position.
            * ``labels``: the padded label ids without their first position, with padding
              and prompt positions replaced by -100 so the loss ignores them.
        """
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = {"input_features": [feature["input_features"] for feature in features]}
        label_features = {"input_ids": [feature["labels"] for feature in features]}

        # pad the inputs and return them as PyTorch tensors
        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.input_padding,
            return_tensors="pt",
        )

        labels_batch = self.processor.tokenizer.pad(
            label_features,
            max_length=self.max_target_length,
            padding=self.target_padding,
            return_tensors="pt",
        )

        # teacher forcing: the decoder reads positions [0, n-1) and predicts positions [1, n)
        label_ids = labels_batch["input_ids"]
        decoder_input_ids = label_ids[:, :-1]
        labels = label_ids[:, 1:]
        labels_mask = labels_batch.attention_mask[:, 1:]

        # replace padding with -100 to ignore correctly when computing the loss
        labels = labels.masked_fill(labels_mask.ne(1), -100)

        # Mask the prompt: every label position up to and including the first
        # decoder_start_token_id. Rows where that token does not occur in `labels` are left
        # untouched (presumably no prompt: the start token sat at position 0 and was shifted
        # into decoder_input_ids).
        # Presence is tested explicitly because argmax also returns 0 when the token is absent.
        # Testing `argmax > 0` instead would leave the start token unmasked when the prompt is
        # empty (it then sits at label position 0).
        is_bos = labels == self.decoder_start_token_id
        has_bos = is_bos.any(dim=1)
        first_bos = is_bos.long().argmax(dim=1)
        bos_index = torch.where(has_bos, first_bos + 1, torch.zeros_like(first_bos))
        prompt_mask = torch.arange(labels.shape[1], device=labels.device) < bos_index[:, None]
        labels = labels.masked_fill(prompt_mask, -100)

        batch["labels"] = labels
        batch["decoder_input_ids"] = decoder_input_ids

        return batch

In [7]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    # special-token ids the collator uses to locate and mask the prompt in each label sequence
    decoder_start_token_id=decoder_start_token_id,
    decoder_prev_token_id=decoder_prev_token_id,
    # input features: pad only up to the longest item in the batch
    input_padding="longest",
    # labels: always pad to a fixed length of 448 tokens
    max_target_length=448,
    target_padding="max_length",
)

In [8]:
# Serve the dataset in batches of 128 rows, padded by `data_collator`.
train_dataloader = DataLoader(small_ds, batch_size=128, collate_fn=data_collator)

In [21]:
# Build SmallWhisper and load the pretrained 'small' weights via smallwhisper's from_pretrained.
# NOTE(review): SmallWhisperConfig is passed as the class itself rather than an instance
# (SmallWhisperConfig()), and from_pretrained is invoked on a freshly constructed model.
# If from_pretrained is a classmethod that builds its own model (nanoGPT style), the
# SmallWhisper(SmallWhisperConfig) instance here is constructed only to be discarded.
# TODO confirm in smallwhisper and, if so, call SmallWhisper.from_pretrained('small') directly.
model = SmallWhisper(SmallWhisperConfig).from_pretrained('small')

loading weights from pretrained gpt: small


  checkpoint = torch.load(fp, map_location=device)


In [24]:
# Freeze everything except the LoRA parameters; bias='lora_only' is loralib's option for
# also training the biases of LoRA layers.
lora.mark_only_lora_as_trainable(model, bias='lora_only')

# Sanity check: if the model contains no loralib layers, the call above silently freezes
# every parameter and the optimizer below would have nothing to train.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
assert n_trainable > 0, "no trainable parameters left; does SmallWhisper contain any loralib layers?"
(n_trainable, n_total)

In [None]:
# AdamW hyperparameters. `optim_groups`, `betas` and `use_fused` were referenced here but
# never defined in this notebook (NameError on a fresh kernel). betas/weight_decay below are
# PyTorch's AdamW defaults. TODO tune.
betas = (0.9, 0.999)
weight_decay = 0.01

# Optimize only the parameters left trainable by mark_only_lora_as_trainable.
trainable_params = [p for p in model.parameters() if p.requires_grad]
# Apply weight decay to matrices (dim >= 2) but not to biases / other 1-D parameters.
decay_params = [p for p in trainable_params if p.dim() >= 2]
no_decay_params = [p for p in trainable_params if p.dim() < 2]
optim_groups = [
    {"params": params, "weight_decay": wd}
    for params, wd in ((decay_params, weight_decay), (no_decay_params, 0.0))
    if params  # drop empty groups; AdamW still raises if nothing is left at all
]
# The fused implementation is only requested when every trainable parameter is on CUDA.
use_fused = bool(trainable_params) and all(p.is_cuda for p in trainable_params)
optimizer = torch.optim.AdamW(optim_groups, lr=3e-4, betas=betas, fused=use_fused)

In [None]:
for batch in train_dataloader:
    