# 05 – Transfer Learning with YAMNet Embeddings

**Course:** CSCI 6366 (Neural Networks and Deep Learning)  
**Project:** Audio Classification using CNN  
**Notebook:** Transfer Learning with Pre-trained Audio Embeddings (YAMNet)

---

## Overview

In this notebook, we extend our previous experiments (02–04) by using
**transfer learning** with a pre-trained audio model, **YAMNet**, to classify
animal sounds (`dog`, `cat`, `bird`).

Instead of training a CNN from scratch on Mel-spectrograms, we:

1. Use a pre-trained YAMNet model (trained on AudioSet) to extract
   high-level audio **embeddings** from each waveform.

2. Train a small neural network (Dense layers) **on top of these embeddings**
   to classify our three animal classes.

3. Compare this transfer-learning approach to our best CNN from
   `04_cnn_full_data.ipynb` (CNN + Dropout 0.3).
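Steps 1 and 2 depend on turning a variable-length waveform into one fixed-size vector. YAMNet returns outputs per frame, so a common choice is to average the 1024-dimensional frame embeddings over time. That choice is an assumption here; later cells may pool differently. The sketch takes the model as an argument so it can run without downloading anything. In practice you would pass the model returned by `hub.load("https://tfhub.dev/google/yamnet/1")`, whose call returns `(scores, embeddings, log_mel_spectrogram)`. The helper name `extract_clip_embedding` and the stand-in `fake_yamnet` are hypothetical.

```python
import numpy as np


def extract_clip_embedding(waveform, yamnet_model):
    """Hypothetical helper: mean-pool YAMNet frame embeddings into one vector.

    Assumes `waveform` is 16 kHz mono audio in [-1, 1] and that `yamnet_model`
    is a callable returning (scores, embeddings, log_mel_spectrogram), as the
    TF Hub YAMNet model does.
    """
    waveform = np.asarray(waveform, dtype=np.float32)
    if waveform.ndim != 1:
        raise ValueError("YAMNet expects a 1-D (mono) waveform")
    _scores, embeddings, _spectrogram = yamnet_model(waveform)
    frames = np.asarray(embeddings)  # shape: (num_frames, 1024)
    return frames.mean(axis=0)       # shape: (1024,), averaged over time


def fake_yamnet(waveform):
    """Offline stand-in for YAMNet: 3 frames whose embeddings are all 0, 1, 2."""
    num_frames = 3
    scores = np.zeros((num_frames, 521), dtype=np.float32)  # 521 AudioSet classes
    embeddings = np.tile(np.arange(num_frames, dtype=np.float32)[:, None], (1, 1024))
    spectrogram = np.zeros((96, 64), dtype=np.float32)
    return scores, embeddings, spectrogram


one_second = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
clip_vec = extract_clip_embedding(one_second, fake_yamnet)
print(clip_vec.shape, clip_vec[0])
```

Mean pooling discards temporal order. That trade-off is usually acceptable for short, single-event clips such as isolated animal sounds.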

**Goals:**

- Reuse the same dataset and class labels as before.
- Keep a similar **train/validation/test split** (stratified, `random_state=42`).
- Evaluate test accuracy, confusion matrix, and per-class metrics.
- Discuss how transfer learning compares to our custom CNN models.
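The evaluation goals map onto the two scikit-learn functions imported in the setup cell. The labels below are made-up toy values for illustration, not results from this project. Rows of the confusion matrix are true classes and columns are predictions, both in `CLASS_NAMES` order.

```python
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

CLASS_NAMES = ["dog", "cat", "bird"]

# Toy labels for illustration only (0 = dog, 1 = cat, 2 = bird).
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])

# Accuracy is the fraction of positions where prediction equals truth.
test_accuracy = np.mean(y_true == y_pred)

# Passing labels= fixes the row/column order even if a class is never predicted.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])

print(f"Test accuracy: {test_accuracy:.3f}")
print(cm)
print(classification_report(y_true, y_pred, labels=[0, 1, 2], target_names=CLASS_NAMES))
```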


## 1. Setup and Configuration

We import TensorFlow, TensorFlow Hub (to load YAMNet), librosa, NumPy, and scikit-learn, and define constants that mirror notebook 04.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import librosa
import librosa.display

import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers, models

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

# For reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Paths and constants
DATA_DIR = Path("../data").resolve()
SAMPLE_RATE = 16000  # YAMNet expects 16 kHz audio

CLASS_NAMES = ["dog", "cat", "bird"]
label_to_index = {label: idx for idx, label in enumerate(CLASS_NAMES)}

# Train/val/test ratios (match notebook 04)
TEST_SIZE = 0.15
VAL_SIZE = 0.15  # fraction of the remaining (non-test) data, i.e. ~12.75% of the total
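`VAL_SIZE` is a fraction of the data left after the test split, so the split is naturally done with two stratified calls to `train_test_split`. The sketch below follows that reading. The helper name `stratified_three_way_split` is hypothetical, and notebook 04 may differ in detail. With these ratios, about 15% of the files go to the test set and about 0.85 × 0.15 ≈ 12.75% to the validation set.

```python
import numpy as np
from sklearn.model_selection import train_test_split

TEST_SIZE = 0.15
VAL_SIZE = 0.15  # fraction of the non-test data


def stratified_three_way_split(X, y, test_size=TEST_SIZE, val_size=VAL_SIZE, seed=42):
    """Hypothetical helper: stratified train/val/test split in two stages."""
    # Stage 1: hold out the test set while preserving class proportions.
    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=seed
    )
    # Stage 2: carve the validation set out of the remaining data.
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest, y_rest, test_size=val_size, stratify=y_rest, random_state=seed
    )
    return X_train, X_val, X_test, y_train, y_val, y_test


# Toy data: 100 items per class with labels 0/1/2; X holds item indices.
y = np.repeat([0, 1, 2], 100)
X = np.arange(len(y))
X_train, X_val, X_test, y_train, y_val, y_test = stratified_three_way_split(X, y)
print(len(X_train), len(X_val), len(X_test))
```

In the notebook itself, `X` would be the `file_paths` array and `y` the `labels` array. Splitting the paths before extracting embeddings keeps each clip in exactly one subset.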


## 2. Dataset: File Paths and Labels

We reuse the same `data/` directory structure:

- `data/dog/*.wav`
- `data/cat/*.wav`
- `data/bird/*.wav`

Here we collect:

- `file_paths`: list of `Path` objects to WAV files
- `labels`: integer label indices (`0 = dog`, `1 = cat`, `2 = bird`)


In [None]:
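def collect_file_paths_and_labels(data_dir: Path):
    """Collect all .wav file paths and their integer labels, class by class."""
    file_paths = []
    labels = []

    for label in CLASS_NAMES:
        class_dir = data_dir / label
        # glob() on a missing directory silently yields nothing, so fail loudly instead
        if not class_dir.is_dir():
            raise FileNotFoundError(f"Missing class directory: {class_dir}")
        wav_files = sorted(class_dir.glob("*.wav"))  # sorted for a deterministic order

        for audio_path in wav_files:
            file_paths.append(audio_path)
            labels.append(label_to_index[label])

    # Paths become an object array, which train_test_split can index like any array
    return np.array(file_paths, dtype=object), np.array(labels, dtype=np.int32)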

file_paths, labels = collect_file_paths_and_labels(DATA_DIR)
print("Total files:", len(file_paths))
for idx, label_name in enumerate(CLASS_NAMES):
    count = np.sum(labels == idx)
    print(f"{label_name}: {count} files")
