In [None]:
import tensorflow as tf
import os
import IPython.display as ipd
from scipy.io import wavfile
import numpy as np
import librosa
from utils import (
    get_datasets,
    waveform_to_spectrograms,
    waveform_to_log_mel_spectrogram,
    eval_and_save,
    get_callbacks,
    get_background_noise,
    augment_fn,
)
import matplotlib.pyplot as plt

In [None]:
# Load the train / validation / test splits from the project-local helper.
# NOTE(review): the element structure is defined in utils.get_datasets; the
# cells below iterate the splits as (waveform, label) pairs — confirm there.
ds_train_raw, ds_val_raw, ds_test_raw = get_datasets()

In [None]:
class MyModel(tf.keras.Model):
    """Classification head stacked on a sequence-to-sequence backbone.

    The backbone (stored as ``self.encoder``) is called with audio features
    and decoder token ids; the hidden state at the last decoder position is
    passed through a 128-unit ReLU layer and a 30-unit linear layer.
    The final layer has no activation, so the outputs are unnormalised scores.
    """

    def __init__(self, encoder):
        """Store the backbone and create the two dense head layers.

        Args:
            encoder: Callable model accepting ``(input_features,
                decoder_input_ids=...)`` and returning an object exposing
                ``last_hidden_state``.
        """
        super().__init__()
        # Attribute names and creation order are kept stable on purpose:
        # saved weight files are matched against this layer structure.
        self.encoder = encoder
        self.d1 = tf.keras.layers.Dense(128, activation="relu")
        self.d2 = tf.keras.layers.Dense(30)

    def call(self, x):
        """Run the backbone and the head.

        Args:
            x: Pair ``(input_features, decoder_input_ids)``.

        Returns:
            Tensor produced by the final 30-unit dense layer, one row per
            example in the batch.
        """
        input_features, decoder_input_ids = x
        backbone_output = self.encoder(
            input_features, decoder_input_ids=decoder_input_ids
        )
        # Keep only the hidden state of the last position along axis 1.
        final_position = backbone_output.last_hidden_state[:, -1, :]
        hidden = self.d1(final_position)
        return self.d2(hidden)

In [None]:
from transformers import AutoFeatureExtractor, TFWhisperModel

# Pretrained Whisper feature extractor + backbone; only the dense head in
# MyModel is trained because the backbone is frozen below.
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
model = MyModel(TFWhisperModel.from_pretrained("openai/whisper-tiny"))
model.encoder.trainable = False

batch_size = 2
ds_train = ds_train_raw.batch(batch_size)

optimizer = tf.keras.optimizers.Lion(learning_rate=1e-5)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Resume from the checkpoint written at the end of epoch 0.
# NOTE(review): the model has not been called yet at this point — confirm the
# head layers' weights are actually restored by this call.
model.load_weights("../epoch_0.weights.h5")

# Token id used to fill the (batch, 2) decoder input.
# NOTE(review): presumably Whisper's start-of-transcript id — confirm against
# model.encoder.config.decoder_start_token_id.
decoder_start_token_id = 50258
sampling_rate = 16000
log_every = 100        # print the running loss every N steps
checkpoint_every = 1000  # save intermediate weights every N steps

epochs = 2
# Epoch 0 was already trained (see load_weights above), so start at 1.
for epoch in range(1, epochs):
    epoch_loss = 0.0
    samples_seen = 0
    step = 0
    for x_batch, y_batch in ds_train:
        step += 1
        n_examples = x_batch.shape[0]

        # Preprocessing is not differentiated, so it stays outside the tape.
        # Select the example first and squeeze afterwards: squeezing the
        # whole batch would also drop the batch axis when a batch holds a
        # single example (e.g. the final partial batch).
        inputs = [
            feature_extractor(
                tf.squeeze(x_batch[i]),
                return_tensors="tf",
                sampling_rate=sampling_rate,
            ).input_features
            for i in range(n_examples)
        ]
        inputs = tf.concat(inputs, axis=0)
        decoder_input_ids = tf.fill((n_examples, 2), decoder_start_token_id)

        with tf.GradientTape() as tape:
            # Forward pass
            outputs = model((inputs, decoder_input_ids))
            loss = loss_fn(y_batch, outputs)

        # Backward pass
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # The loss is a per-batch mean; weight it by the real batch size so
        # a short final batch is not over-counted.
        epoch_loss += float(loss.numpy()) * n_examples
        samples_seen += n_examples

        if step % log_every == 0:
            print(f"epoch {epoch} step {step}: loss {float(loss.numpy()):.4f}")
        if step % checkpoint_every == 0:
            model.save_weights(f"epoch_{epoch}_step_{step}.weights.h5")

    # Print the average loss for the epoch (guarding against an empty dataset).
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss / max(samples_seen, 1)}")
    model.save_weights(f"epoch_{epoch}.weights.h5")

In [None]:
# Batch the raw test split; batch_size is the value set in the training cell.
ds_test = ds_test_raw.batch(batch_size)

In [None]:
# Collect the predicted class index for every test example.
preds = []
for x_batch, y_batch in ds_test:
    n_examples = x_batch.shape[0]

    # Select the example first and squeeze afterwards: squeezing the whole
    # batch would also drop the batch axis when a batch holds a single
    # example (e.g. the final partial batch).
    inputs = [
        feature_extractor(
            tf.squeeze(x_batch[i]), return_tensors="tf", sampling_rate=16000
        ).input_features
        for i in range(n_examples)
    ]
    inputs = tf.concat(inputs, axis=0)

    # Same fixed two-token decoder prompt as used during training.
    # NOTE(review): 50258 is presumably Whisper's start-of-transcript id —
    # confirm against the model config.
    decoder_input_ids = tf.fill((n_examples, 2), 50258)

    outputs = model((inputs, decoder_input_ids))
    # The model returns one score per class; argmax over axis 1 picks the class.
    preds.extend(tf.argmax(outputs, axis=1).numpy())

In [None]:
# Render a diagram of the model with tensor shapes.
# NOTE(review): MyModel is a subclassed model, so the diagram may show only a
# single node rather than the internal layers — confirm the output is useful.
tf.keras.utils.plot_model(model, show_shapes=True)