In [None]:
# Step 1: Force correct versions
!pip uninstall -y tensorflow-io
!pip install -q tensorflow==2.15.0 tensorflow-io==0.35.0 tensorflow-hub==0.15.0

# Step 2: Restart runtime after install
import os
os.kill(os.getpid(), 9)


Found existing installation: tensorflow-io 0.35.0
Uninstalling tensorflow-io-0.35.0:
  Successfully uninstalled tensorflow-io-0.35.0


In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_io as tfio

In [None]:
# Report the installed library versions (TensorFlow, TF Hub, TF I/O) so they
# can be compared against the pins in the install cell.
for module in (tf, hub, tfio):
    print(module.__version__)

2.15.0
0.15.0
0.35.0


In [None]:
# Load the pretrained YAMNet model from TensorFlow Hub.
YAMNET_HANDLE = 'https://tfhub.dev/google/yamnet/1'
yamnet_model = hub.load(YAMNET_HANDLE)

Loads the full, pretrained YAMNet model from TensorFlow Hub.

It takes a 1-D float32 waveform (mono, sampled at 16 kHz) as input.

Outputs: a tuple `(scores, embeddings, spectrogram)`.



In [None]:
# Create a wrapper model that outputs only the pooled embedding
class YamnetEmbeddingModel(tf.Module):
    """Wrap YAMNet so that calling the module returns only a pooled embedding.

    YAMNet returns ``(scores, embeddings, spectrogram)``; this wrapper discards
    the scores and spectrogram and averages the embeddings over axis 0,
    producing a single vector per input waveform.
    """

    def __init__(self, yamnet=None):
        """Store the YAMNet model to wrap.

        Args:
            yamnet: A loaded YAMNet model (e.g. the result of ``hub.load``).
                When omitted, falls back to the notebook-level ``yamnet_model``
                for backward compatibility; passing it explicitly avoids
                depending on hidden global state.
        """
        super().__init__()
        # Assigning the model to an attribute lets tf.Module track it, so its
        # variables are included when this module is exported as a SavedModel.
        self.yamnet = yamnet_model if yamnet is None else yamnet

    @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.float32)])
    def __call__(self, audio):
        """Compute the mean-pooled YAMNet embedding for one waveform.

        Args:
            audio: 1-D float32 tensor of arbitrary length. Presumably a mono
                16 kHz waveform as YAMNet expects; this is not checked here.

        Returns:
            A dict ``{'embedding': Tensor}`` holding the mean of YAMNet's
            embeddings over axis 0.
        """
        _, embeddings, _ = self.yamnet(audio)
        # Average over axis 0 (one row per frame) to get a single clip vector.
        pooled = tf.reduce_mean(embeddings, axis=0)  # shape (1024,)
        return {'embedding': pooled}


`YamnetEmbeddingModel` is a custom `tf.Module` subclass that wraps YAMNet and filters its output, keeping only the mean-pooled embedding.

In [None]:
# Instantiate the wrapper; it wraps the `yamnet_model` loaded above and is the
# object exported as a SavedModel in the next cell.
wrapper = YamnetEmbeddingModel()


In [None]:

# Save it as a SavedModel.
# Pass the serving signature explicitly instead of relying on
# tf.saved_model.save's automatic detection, which only exports a default
# signature when exactly one suitable tf.function is found on the root object.
# The TFLite converter in the next step converts the SavedModel's signatures,
# so at least one must be present.
tf.saved_model.save(
    wrapper,
    "yamnet_embedding_model",
    signatures={"serving_default": wrapper.__call__.get_concrete_function()},
)


This model takes raw audio (a mono, 16 kHz, 1-D float32 waveform) and returns only the mean-pooled embedding, under the `embedding` key.

In [None]:

# Convert to TFLite.
# Optimize.DEFAULT (with no representative dataset) applies post-training
# dynamic-range quantization of the weights: a smaller file, possibly with
# slightly different embeddings. Remove that line to keep float32 weights.
# NOTE: the exported input signature is [None] (dynamic length). When running
# the .tflite file, call `Interpreter.resize_tensor_input` with the clip length
# and then `allocate_tensors()` before invoking.
converter = tf.lite.TFLiteConverter.from_saved_model("yamnet_embedding_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # Optional: for smaller model
tflite_model = converter.convert()


In [None]:
# Write the serialized TFLite flatbuffer bytes to disk.
tflite_path = "yamnet_embedding.tflite"
with open(tflite_path, "wb") as tflite_file:
    tflite_file.write(tflite_model)


In [None]:
# Download the exported yamnet_embedding.tflite model to the local machine

In [None]:
# Trigger a browser download of the exported model. `google.colab` only exists
# inside Colab, so outside it just report where the file was written instead
# of failing with ImportError.
TFLITE_PATH = "yamnet_embedding.tflite"
try:
    from google.colab import files
except ImportError:
    print(f"Not running in Colab; the model is saved at {TFLITE_PATH!r}.")
else:
    files.download(TFLITE_PATH)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>