In [16]:
import tensorflow as tf
from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration, WhisperTokenizer
from datasets import load_dataset

In [105]:
# Audio feature-extractor configuration is taken from the fine-tuned checkpoint.
feature_extractor = WhisperFeatureExtractor.from_pretrained(
    "avalonai/whisper-small-id"
)
# NOTE(review): the tokenizer is loaded from the base "openai/whisper-small"
# checkpoint, not from the fine-tuned one -- presumably deliberate; confirm both
# checkpoints share the same vocabulary / special tokens.
tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-small",
    predict_timestamps=True,
)
# Bundle the two so later cells can reach the tokenizer via `processor.tokenizer`.
processor = WhisperProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [18]:
# Load the fine-tuned checkpoint as a TensorFlow model.
model = TFWhisperForConditionalGeneration.from_pretrained("avalonai/whisper-small-id")

# Small dummy validation split, used here only to smoke-test the pipeline.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# Look the first example up once: each `ds[0]` access re-reads the row (and, for
# audio columns, re-decodes the waveform), so indexing twice did the work twice.
audio = ds[0]["audio"]
inputs = feature_extractor(
    audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="tf"
)
input_features = inputs.input_features

# Generate token ids with the model's default generation settings, then decode
# the first (only) sequence back to text without the special tokens.
generated_ids = model.generate(input_features=input_features)
print(generated_ids)
transcription = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(transcription)

All PyTorch model weights were used when initializing TFWhisperForConditionalGeneration.

All the weights of TFWhisperForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFWhisperForConditionalGeneration for predictions without further training.


tf.Tensor(
[[50258 50259 50359 50363   440   342  1220  4316   295  1331  8795 22949
    433    13   467  2516  3738   281  1565   484   264 41176    13   316
   3554 10460  1472  2706  1585   293 37889    13   316  5139 31433  8666
   2489   365  7852    13 37992   311   419 21193   366   452  2954    13
    316 37889   906  1755   307   264  2368  3278  6702    13 50257]], shape=(1, 59), dtype=int32)
 The stale smell of old beer lingers. It takes heat to bring out the odor. A cold dip restores health and zest. A salt pickle tastes fine with ham. Taco's al pastor are my favorite. A zestful food is the hot cross bun.


In [19]:
# Persist the loaded Keras model to disk.
# NOTE(review): the TFLite export cell further down writes its SavedModel to this
# same directory ("./sane/whisper-id-small") -- confirm this earlier save is
# still needed, since that export will write over it.
model.save(filepath='./sane/whisper-id-small')



INFO:tensorflow:Assets written to: ./sane/whisper-id-small\assets


INFO:tensorflow:Assets written to: ./sane/whisper-id-small\assets


In [120]:
class WrapModel(tf.Module):
    """Expose `model.generate` as a fixed-signature function suitable for export.

    Args:
        model: Object providing a `generate(input_features, max_new_tokens=...,
            return_dict_in_generate=True)` method whose result can be indexed
            with "sequences".
        max_new_tokens: Upper bound on generated tokens passed to `generate`.
            Defaults to 450, the value that was previously hard-coded.
            NOTE(review): Whisper configs typically declare a decoder position
            limit (`max_target_positions`, commonly 448) -- confirm the chosen
            value does not exceed `model.config.max_target_positions`.

    Raises:
        ValueError: If `max_new_tokens` is not a positive integer.
    """

    def __init__(self, model, max_new_tokens=450):
        super().__init__()
        max_new_tokens = int(max_new_tokens)
        if max_new_tokens < 1:
            raise ValueError(
                f"max_new_tokens must be a positive integer, got {max_new_tokens}"
            )
        self.model = model
        # Plain Python int: it is read at trace time and baked into the graph.
        self.max_new_tokens = max_new_tokens

    @tf.function(
        # Fully static input shape (batch of 1, 80 x 3000 float32 features).
        input_signature=[
            tf.TensorSpec((1, 80, 3000), tf.float32, name="input_features"),
        ],
    )
    def serving(self, input_features):
        """Run generation on one batch of input features.

        Args:
            input_features: float32 tensor of shape (1, 80, 3000).

        Returns:
            Dict with a single key, "sequences", holding the generated ids.
        """
        outputs = self.model.generate(
            input_features,
            max_new_tokens=self.max_new_tokens,
            return_dict_in_generate=True,
        )
        return {"sequences": outputs["sequences"]}

# Output locations: intermediate SavedModel directory and final TFLite file.
saved_model_dir = "./sane/whisper-id-small"
tflite_model_path = "./sane/whisper-id-small.tflite"

# Step 1 -- export the generation wrapper as a SavedModel whose default serving
# signature is the wrapper's `serving` function.
generate_model = WrapModel(model=model)
tf.saved_model.save(
    generate_model,
    saved_model_dir,
    signatures={"serving_default": generate_model.serving},
)

# Step 2 -- convert the SavedModel to TFLite. Both the TFLite builtin op set and
# the select-TF-ops set are enabled, together with the default optimizations.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Step 3 -- write the converted model bytes to disk.
with open(tflite_model_path, mode='wb') as tflite_file:
    tflite_file.write(tflite_model)

KeyboardInterrupt: 