# 05 – Transfer Learning with YAMNet Embeddings

**Course:** CSCI 6366 (Neural Networks and Deep Learning)  
**Project:** Audio Classification using CNN  
**Notebook:** Transfer Learning with Pre-trained Audio Embeddings (YAMNet)

---

## Overview

In this notebook, we extend our previous experiments (02–04) by using
**transfer learning** with a pre-trained audio model, **YAMNet**, to classify
animal sounds (`dog`, `cat`, `bird`).

Instead of training a CNN from scratch on Mel-spectrograms, we:

1. Use a pre-trained YAMNet model (trained on AudioSet) to extract
   high-level audio **embeddings** from each waveform.

2. Train a small neural network (Dense layers) **on top of these embeddings**
   to classify our three animal classes.

3. Compare this transfer-learning approach to our best CNN from
   `04_cnn_full_data.ipynb` (CNN + Dropout 0.3).

**Goals:**

- Reuse the same dataset and class labels as before.
- Keep the same **train/validation/test split** settings (stratified, `random_state=42`).
- Evaluate test accuracy, confusion matrix, and per-class metrics.
- Discuss how transfer learning compares to our custom CNN models.


## 1. Setup and Configuration

We import TensorFlow, TensorFlow Hub (for YAMNet), librosa, NumPy, and scikit-learn, and set constants similar to those in notebook 04.


In [2]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import librosa
import librosa.display

import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers, models

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

# For reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Paths and constants
DATA_DIR = Path("../data").resolve()
SAMPLE_RATE = 16000  # YAMNet expects 16 kHz audio

CLASS_NAMES = ["dog", "cat", "bird"]
label_to_index = {label: idx for idx, label in enumerate(CLASS_NAMES)}

# Train/val/test ratios (match notebook 04)
TEST_SIZE = 0.15
VAL_SIZE = 0.15  # fraction of the data that remains after the test split


## 2. Dataset: File Paths and Labels

We reuse the same `data/` directory structure:

- `data/dog/*.wav`
- `data/cat/*.wav`
- `data/bird/*.wav`

Here we collect:

- `file_paths`: NumPy array of `Path` objects pointing to the WAV files
- `labels`: integer label indices (`0 = dog`, `1 = cat`, `2 = bird`)


In [3]:
def collect_file_paths_and_labels(data_dir: Path):
    """Collect all .wav file paths and integer labels."""
    file_paths = []
    labels = []
    
    for label in CLASS_NAMES:
        class_dir = data_dir / label
        wav_files = sorted(class_dir.glob("*.wav"))
        
        for audio_path in wav_files:
            file_paths.append(audio_path)
            labels.append(label_to_index[label])
    
    return np.array(file_paths), np.array(labels, dtype=np.int32)

file_paths, labels = collect_file_paths_and_labels(DATA_DIR)
print("Total files:", len(file_paths))
for idx, label_name in enumerate(CLASS_NAMES):
    count = np.sum(labels == idx)
    print(f"{label_name}: {count} files")


Total files: 610
dog: 210 files
cat: 207 files
bird: 193 files
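
As a point of reference for the accuracies reported later, the class counts above imply a majority-class baseline. The sketch below only redoes that arithmetic; `class_counts` and the other names are introduced here for illustration and are not part of the original notebook.

```python
# Class counts copied from the cell output above
class_counts = {"dog": 210, "cat": 207, "bird": 193}

total = sum(class_counts.values())
majority_label = max(class_counts, key=class_counts.get)

# Accuracy of a trivial model that always predicts the most frequent class
majority_baseline = class_counts[majority_label] / total

print(f"Total files: {total}")
print(f"Always predicting '{majority_label}' would score {majority_baseline:.3f}")
```

With three nearly balanced classes, this baseline sits at roughly one third, so any useful classifier should score well above it.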


## 3. Stratified Train / Validation / Test Split

We create explicit train, validation, and test sets using stratified splits.
We use the **same ratios and random_state=42** as in `04_cnn_full_data.ipynb` so
the splits are comparable:

- Test set: 15% of the data
- Validation set: 15% of the remaining 85% (about 12.75% of the full dataset),
  which leaves roughly 72% for training


In [4]:
# First split off test set
paths_train_full, paths_test, y_train_full, y_test = train_test_split(
    file_paths,
    labels,
    test_size=TEST_SIZE,
    random_state=42,
    stratify=labels,
)

# Now split train_full into train and validation
paths_train, paths_val, y_train, y_val = train_test_split(
    paths_train_full,
    y_train_full,
    test_size=VAL_SIZE,
    random_state=42,
    stratify=y_train_full,
)

print("Train size:", len(paths_train))
print("Validation size:", len(paths_val))
print("Test size:", len(paths_test))

def print_split_stats(name, y_split):
    print(f"\n{name} distribution:")
    for idx, label_name in enumerate(CLASS_NAMES):
        count = np.sum(y_split == idx)
        print(f"  {label_name}: {count}")

print_split_stats("Train", y_train)
print_split_stats("Validation", y_val)
print_split_stats("Test", y_test)


Train size: 440
Validation size: 78
Test size: 92

Train distribution:
  dog: 151
  cat: 150
  bird: 139

Validation distribution:
  dog: 27
  cat: 26
  bird: 25

Test distribution:
  dog: 32
  cat: 31
  bird: 29
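
The split sizes printed above can be reproduced by hand. This is a small sketch under one assumption: the held-out count of each split is rounded up, and everything else stays in the training portion. That assumption matches the sizes reported here; the lowercase variable names are illustrative.

```python
import math

n_total = 610       # total files reported in Section 2
test_size = 0.15
val_size = 0.15     # applied to what remains after the test split

# Assumption: the held-out count is rounded up; the rest stays for training
n_test = math.ceil(test_size * n_total)
n_train_full = n_total - n_test
n_val = math.ceil(val_size * n_train_full)
n_train = n_train_full - n_val

print("train / val / test:", n_train, n_val, n_test)
for name, n in [("train", n_train), ("val", n_val), ("test", n_test)]:
    print(f"{name}: {n / n_total:.1%} of all files")
```

Relative to the full dataset, the effective proportions come out at roughly 72% / 13% / 15%, not 70% / 15% / 15%, because the validation fraction applies to the post-test remainder.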


## 4. YAMNet: Pre-trained Audio Model

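We use **YAMNet**, a convolutional neural network (MobileNetV1 architecture)
trained on Google's AudioSet dataset. It takes a 16 kHz mono waveform as a
1-D float32 tensor and outputs:

- Class scores for 521 audio event classes, one row per frame,
- Intermediate **embeddings** (one 1024-dimensional vector per frame; each
  frame covers about 0.96 s of audio, with a new frame every 0.48 s),
- Log Mel-spectrogram (internal representation).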

For transfer learning, we will:

1. Load each audio file at 16 kHz (mono).
2. Pass the waveform through YAMNet.
3. Take the **average** of all frame-level embeddings to get a single
   1024-D embedding vector per clip.
4. Use these embeddings as input features to a small classifier network.
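
Step 3 above (averaging over frames) can be illustrated without loading YAMNet. The sketch below uses random numbers as a stand-in for frame-level embeddings; the array `frame_embeddings` and the frame count of 5 are assumptions made for the example only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for YAMNet output: 5 frames, each a 1024-D embedding (random values)
frame_embeddings = rng.normal(size=(5, 1024)).astype(np.float32)

# Mean-pool over the frame axis to get one fixed-length vector per clip
clip_embedding = frame_embeddings.mean(axis=0)

print(frame_embeddings.shape, "->", clip_embedding.shape)
```

Clips of different durations produce different numbers of frames, but the pooled vector always has 1024 entries, which is what lets a fixed-size Dense classifier consume it.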


In [None]:
# YAMNet TF Hub handle (internet access is needed the first time this runs)
YAMNET_HANDLE = "https://tfhub.dev/google/yamnet/1"

# Work around SSL certificate errors on macOS (if needed)
import os
import certifi

# Point Python's SSL stack at certifi's CA bundle for TensorFlow Hub downloads
os.environ['SSL_CERT_FILE'] = certifi.where()

print("Loading YAMNet from TensorFlow Hub...")
try:
    yamnet_model = hub.load(YAMNET_HANDLE)
    print("YAMNet loaded.")
except Exception as e:
    print(f"Error loading YAMNet: {e}")
    print("\nIf you see an SSL certificate error, try:")
    print("1. Install certifi: pip install certifi")
    print("2. Or run: /Applications/Python\\ 3.11/Install\\ Certificates.command")
    print("3. Or manually download YAMNet and use a local path")
    raise


Loading YAMNet from TensorFlow Hub...


URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1002)>

## 5. Helper Functions: Waveforms → YAMNet Embeddings

We now define:

- `load_waveform(path)`: load a mono waveform at 16 kHz using `librosa`.
- `yamnet_embedding_for_file(path)`: run YAMNet and average frame-level
  embeddings to get a single 1024-D vector.
- `compute_embeddings(file_paths)`: apply this to a list of paths and
  return an array of shape `(N, embedding_dim)`.


In [None]:
def load_waveform(path: Path, sample_rate: int = SAMPLE_RATE) -> np.ndarray:
    """Load a mono waveform at the desired sample rate."""
    # librosa resamples to `sample_rate` and downmixes to mono
    y, _ = librosa.load(path, sr=sample_rate, mono=True)
    return y.astype(np.float32)

def yamnet_embedding_for_file(path: Path) -> np.ndarray:
    """Compute a single YAMNet embedding vector for one audio file."""
    waveform = load_waveform(path)  # shape: (num_samples,)

    # YAMNet takes a 1-D float32 waveform of shape (num_samples,);
    # it does not accept a batch dimension, so we pass one clip at a time.
    waveform_tf = tf.convert_to_tensor(waveform, dtype=tf.float32)

    # YAMNet returns (scores, embeddings, log-mel spectrogram)
    _, embeddings, _ = yamnet_model(waveform_tf)
    # embeddings shape: (num_frames, 1024)

    # Average over time frames to get a single 1024-D vector
    embedding_mean = tf.reduce_mean(embeddings, axis=0)  # shape: (1024,)
    return embedding_mean.numpy()

def compute_embeddings(file_paths: np.ndarray) -> np.ndarray:
    """Compute YAMNet embeddings for a list/array of file paths."""
    all_embeddings = []
    for i, path in enumerate(file_paths):
        if i % 50 == 0:
            print(f"Processing file {i}/{len(file_paths)}: {path.name}")
        emb = yamnet_embedding_for_file(path)
        all_embeddings.append(emb)
    return np.stack(all_embeddings, axis=0)


## 6. Build Embeddings for Train / Validation / Test

We now compute YAMNet embeddings for:

- `paths_train` → `X_train_embed`
- `paths_val` → `X_val_embed`
- `paths_test` → `X_test_embed`

Then we convert integer labels (`0,1,2`) into one-hot vectors of length 3.


In [None]:
# Compute embeddings (this may take some time on CPU)
X_train_embed = compute_embeddings(paths_train)
X_val_embed = compute_embeddings(paths_val)
X_test_embed = compute_embeddings(paths_test)

print("Embedding shapes:")
print("  Train:", X_train_embed.shape)
print("  Val:  ", X_val_embed.shape)
print("  Test: ", X_test_embed.shape)

# One-hot encode labels
num_classes = len(CLASS_NAMES)

def to_one_hot(y_int: np.ndarray, num_classes: int) -> np.ndarray:
    y_one_hot = np.zeros((len(y_int), num_classes), dtype=np.float32)
    for i, idx in enumerate(y_int):
        y_one_hot[i, idx] = 1.0
    return y_one_hot

y_train_oh = to_one_hot(y_train, num_classes)
y_val_oh = to_one_hot(y_val, num_classes)
y_test_oh = to_one_hot(y_test, num_classes)

print("Label shapes (one-hot):")
print("  Train:", y_train_oh.shape)
print("  Val:  ", y_val_oh.shape)
print("  Test: ", y_test_oh.shape)
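
The loop-based `to_one_hot` above is easy to follow. For reference, the same encoding can be written without a Python loop by indexing the rows of an identity matrix. This is an optional sketch; `to_one_hot_vectorized` and the example labels are hypothetical names and values, not part of the original notebook.

```python
import numpy as np

def to_one_hot_vectorized(y_int: np.ndarray, num_classes: int) -> np.ndarray:
    """One-hot encode integer labels by indexing rows of an identity matrix."""
    return np.eye(num_classes, dtype=np.float32)[y_int]

# Small illustrative label array (0 = dog, 1 = cat, 2 = bird)
example_labels = np.array([0, 2, 1, 2], dtype=np.int32)
example_one_hot = to_one_hot_vectorized(example_labels, num_classes=3)
print(example_one_hot)
```

Row `i` of the result has a single `1.0` in column `example_labels[i]`, so taking `argmax` along axis 1 recovers the original integer labels.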


## 7. Plotting Helper for Training Curves

We reuse the same helper function to plot training and validation loss
and accuracy over epochs.


In [None]:
def plot_training_curves(history, title_prefix=""):
    """Plot training and validation loss/accuracy."""
    history_dict = history.history

    train_loss = history_dict.get("loss", [])
    val_loss = history_dict.get("val_loss", [])
    train_acc = history_dict.get("accuracy", [])
    val_acc = history_dict.get("val_accuracy", [])

    epochs = range(1, len(train_loss) + 1)

    plt.figure(figsize=(12, 4))

    # Loss
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_loss, label="Train loss")
    if val_loss:
        plt.plot(epochs, val_loss, label="Val loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"{title_prefix} Training vs Validation Loss")
    plt.legend()

    # Accuracy
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_acc, label="Train acc")
    if val_acc:
        plt.plot(epochs, val_acc, label="Val acc")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title(f"{title_prefix} Training vs Validation Accuracy")
    plt.legend()

    plt.tight_layout()
    plt.show()


## 8. Define the Transfer-Learning Classifier

We now define a simple classifier that takes YAMNet embeddings as input:

- Input: 1024-D embedding vector
- Dense(128, ReLU) + Dropout(0.3)
- Dense(3, Softmax)

We use the same loss and optimizer as before (`categorical_crossentropy` +
`adam`).


In [None]:
def build_yamnet_classifier(input_dim: int, num_classes: int = 3, dropout_rate: float = 0.3):
    model = tf.keras.Sequential([
        layers.Input(shape=(input_dim,)),
        layers.Dense(128, activation="relu"),
        layers.Dropout(dropout_rate),
        layers.Dense(num_classes, activation="softmax"),
    ])

    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

input_dim = X_train_embed.shape[1]
yamnet_model_head = build_yamnet_classifier(input_dim=input_dim, num_classes=num_classes)
yamnet_model_head.summary()
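
As a sanity check on what `summary()` should report, the number of trainable parameters can be worked out by hand, assuming the embedding dimension is 1024 as described in Section 4. The variable names below are illustrative.

```python
embed_dim = 1024   # assumed YAMNet embedding size
n_hidden = 128     # units in the hidden Dense layer
n_out = 3          # dog, cat, bird

# A Dense layer has (inputs * units) weights plus one bias per unit;
# Dropout adds no parameters.
hidden_params = embed_dim * n_hidden + n_hidden
output_params = n_hidden * n_out + n_out
total_params = hidden_params + output_params

print("Hidden layer:", hidden_params)
print("Output layer:", output_params)
print("Total:", total_params)
```

Under these assumptions the head has about 131,600 parameters, nearly all of them in the first Dense layer. That is a lot of parameters relative to 440 training clips, which is one reason to keep the Dropout layer.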


## 9. Training the YAMNet Embedding Classifier

We train the classifier on top of YAMNet embeddings using the same
train/validation split as before.

- Epochs: 20
- Batch size: 16

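(The classifier is small and the embeddings are precomputed, so training
should be quick. Adjust the number of epochs if the validation loss is still
falling at epoch 20 or starts rising earlier.)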


In [None]:
EPOCHS = 20
BATCH_SIZE = 16

yamnet_history = yamnet_model_head.fit(
    X_train_embed,
    y_train_oh,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(X_val_embed, y_val_oh),
    verbose=1,
)

plot_training_curves(yamnet_history, title_prefix="YAMNet Embedding Classifier")
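
To get a feel for how much optimization these settings amount to, the number of gradient updates can be computed from the training-set size reported in Section 3. This is simple arithmetic; the variable names are illustrative.

```python
import math

n_train_clips = 440   # training clips from the split in Section 3
batch_size = 16
n_epochs = 20

# The final batch of each epoch is smaller when the sizes do not divide evenly
steps_per_epoch = math.ceil(n_train_clips / batch_size)
last_batch_size = n_train_clips - (steps_per_epoch - 1) * batch_size
total_updates = steps_per_epoch * n_epochs

print("Steps per epoch:", steps_per_epoch)
print("Size of last batch:", last_batch_size)
print("Total gradient updates:", total_updates)
```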


## 10. Evaluation on the Held-out Test Set

We now evaluate the YAMNet-based classifier on the test set (92 clips),
and compute:

- Test loss and accuracy
- Confusion matrix
- Per-class precision, recall, and F1-score


In [None]:
def evaluate_model_embeddings(model, X_test, y_test_oh, y_test_int, model_name="Model"):
    """Evaluate a classifier on embedding features and print metrics."""
    test_loss, test_acc = model.evaluate(X_test, y_test_oh, verbose=0)
    print(f"{model_name} - Test loss: {test_loss:.4f}, Test accuracy: {test_acc:.4f}")

    y_pred_probs = model.predict(X_test, verbose=0)
    y_pred = np.argmax(y_pred_probs, axis=1)

    print("\nConfusion matrix:")
    cm = confusion_matrix(y_test_int, y_pred)
    print(cm)

    print("\nClassification report:")
    print(classification_report(y_test_int, y_pred, target_names=CLASS_NAMES))

    return test_loss, test_acc, cm

print("=" * 60)
print("YAMNet Embedding Classifier RESULTS")
print("=" * 60)
yamnet_results = evaluate_model_embeddings(
    yamnet_model_head,
    X_test_embed,
    y_test_oh,
    y_test,
    model_name="YAMNet + Dense Head",
)
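
To make the per-class metrics concrete, the sketch below derives precision, recall, and F1 directly from a confusion matrix, using the same layout that `confusion_matrix` produces (rows are true classes, columns are predicted classes). The matrix `example_cm` is made up for illustration and is not a result from this notebook.

```python
import numpy as np

# Hypothetical confusion matrix (rows = true class, columns = predicted class).
# These numbers are invented for illustration; they are NOT notebook results.
example_cm = np.array([
    [8, 1, 1],   # true dog
    [2, 7, 1],   # true cat
    [0, 1, 9],   # true bird
])

true_positives = np.diag(example_cm)
precision = true_positives / example_cm.sum(axis=0)  # column sums = predicted counts
recall = true_positives / example_cm.sum(axis=1)     # row sums = true counts
f1 = 2 * precision * recall / (precision + recall)
accuracy = true_positives.sum() / example_cm.sum()

for name, p, r, f in zip(["dog", "cat", "bird"], precision, recall, f1):
    print(f"{name}: precision={p:.2f} recall={r:.2f} f1={f:.2f}")
print(f"accuracy={accuracy:.2f}")
```

Reading the matrix this way shows which classes are confused with each other, which overall accuracy alone does not reveal.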
