In [1]:
import tensorflow as tf
import os
import IPython.display as ipd
from scipy.io import wavfile
import numpy as np
import librosa
from utils import (
    get_datasets,
    waveform_to_spectrograms,
    waveform_to_log_mel_spectrogram,
    eval_and_save,
    get_callbacks,
    get_background_noise,
    augment_fn,
)
import matplotlib.pyplot as plt

2024-05-05 16:39:36.991397: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
ds_train_raw, ds_val_raw, ds_test_raw = get_datasets()

Found 51088 files belonging to 30 classes.


2024-05-05 16:39:55.210899: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-05-05 16:39:55.575874: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-05-05 16:39:55.576032: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-05-05 16:39:55.581416: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-05-05 16:39:55.581840: I external/local_xla/xla/stream_executor

Found 6798 files belonging to 30 classes.
Found 6835 files belonging to 30 classes.


In [3]:
class MyModel(tf.keras.Model):
    def __init__(self, encoder):
        super(MyModel, self).__init__()
        self.encoder = encoder
        self.d1 = tf.keras.layers.Dense(128, activation="relu")
        self.d2 = tf.keras.layers.Dense(30)

    def call(self, x):
        input_features, decoder_input_ids = x
        x = self.encoder(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
        x = x[:, -1, :]
        x = self.d1(x)
        return self.d2(x)

In [4]:
from transformers import AutoFeatureExtractor, TFWhisperModel
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
model = MyModel(TFWhisperModel.from_pretrained("openai/whisper-tiny"))
model.encoder.trainable = False

batch_size = 2
ds_train = ds_train_raw.batch(batch_size)

optimizer = tf.keras.optimizers.Lion(learning_rate=1e-5)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)


epochs = 3
for epoch in range(epochs):
    epoch_loss = 0
    step = 0
    for x_batch, y_batch in ds_train:
        step += 1
        with tf.GradientTape() as tape:
            # Forward pass
            inputs = [feature_extractor(tf.squeeze(x_batch)[i], return_tensors="tf", sampling_rate=16000).input_features for i in range(x_batch.shape[0])]
            inputs = tf.concat(inputs, axis=0)

            tmp = [50258] * x_batch.shape[0] * 2
            decoder_input_ids = tf.convert_to_tensor(tmp)
            decoder_input_ids = tf.reshape(decoder_input_ids, (x_batch.shape[0], 2))

            outputs = model((inputs, decoder_input_ids))

            loss = loss_fn(y_batch, outputs)
            print(loss)
        # Backward pass
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        
        epoch_loss += loss.numpy() * batch_size

        if step == 1 and epoch == 1:
            model.load_weights("../models/epoch_0.weights.h5")

        if step % 1000 == 0:
            model.save_weights(f"epoch_{epoch}_step_{step}.weights.h5")

    
    # Print the average loss for the epoch
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss / len(ds_train_raw)}")
    model.save_weights(f"epoch_{epoch}.weights.h5")

All PyTorch model weights were used when initializing TFWhisperModel.

All the weights of TFWhisperModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFWhisperModel for predictions without further training.
2024-05-05 05:19:36.205112: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8907


tf.Tensor(18.127167, shape=(), dtype=float32)
tf.Tensor(0.00094031537, shape=(), dtype=float32)
tf.Tensor(0.0028146363, shape=(), dtype=float32)
tf.Tensor(0.38312045, shape=(), dtype=float32)
tf.Tensor(0.35935244, shape=(), dtype=float32)
tf.Tensor(0.0015949176, shape=(), dtype=float32)
tf.Tensor(1.4084885, shape=(), dtype=float32)
tf.Tensor(0.1125016, shape=(), dtype=float32)
tf.Tensor(0.08510264, shape=(), dtype=float32)
tf.Tensor(0.013127434, shape=(), dtype=float32)
tf.Tensor(0.013230459, shape=(), dtype=float32)
tf.Tensor(0.19136447, shape=(), dtype=float32)
tf.Tensor(0.03973258, shape=(), dtype=float32)
tf.Tensor(0.18771876, shape=(), dtype=float32)
tf.Tensor(0.0063643763, shape=(), dtype=float32)
tf.Tensor(3.3902535, shape=(), dtype=float32)
tf.Tensor(0.003877073, shape=(), dtype=float32)
tf.Tensor(0.41063297, shape=(), dtype=float32)
tf.Tensor(0.015422547, shape=(), dtype=float32)
tf.Tensor(0.016126102, shape=(), dtype=float32)
tf.Tensor(0.0022048794, shape=(), dtype=float32)
t

2024-05-05 10:54:11.710213: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [5]:
ds_test = ds_test_raw.batch(batch_size)

In [31]:
preds = []
loss_total = 0
correct = 0
for x_batch, y_batch in ds_test:
    if x_batch.shape[0] != batch_size:
        continue
    inputs = [feature_extractor(tf.squeeze(x_batch)[i], return_tensors="tf", sampling_rate=16000).input_features for i in range(x_batch.shape[0])]
    inputs = tf.concat(inputs, axis=0)

    tmp = [50258] * x_batch.shape[0] * 2
    decoder_input_ids = tf.convert_to_tensor(tmp)
    decoder_input_ids = tf.reshape(decoder_input_ids, (x_batch.shape[0], 2))

    outputs = model((inputs, decoder_input_ids))
    loss = loss_fn(y_batch, outputs)
    loss_total += loss.numpy() * batch_size

    correct += np.sum(tf.argmax(outputs, axis=1).numpy() == y_batch.numpy())

    preds.extend(tf.argmax(outputs, axis=1).numpy())

2024-05-05 16:14:03.469547: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [12]:
correct / len(ds_test_raw), loss_total / len(ds_test_raw)

(0.945476547654, 0.2779809802487198)