~~~
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
~~~

# Classifying Pneumonia with HeAR and the COUGHVID Dataset

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/Volbis/cough_ai/blob/main/notebooks/train_data_efficient_classifier.ipynb">
      <img alt="Google Colab logo" src="https://www.tensorflow.org/images/colab_logo_32px.png" width="32px"><br> Run in Google Colab
    </a>
  </td>  
  <td style="text-align: center">
    <a href="https://github.com/Volbis/cough_ai/blob/main/notebooks/train_data_efficient_classifier.ipynb">
      <img alt="GitHub logo" src="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" width="32px"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://huggingface.co/google/hear">
      <img alt="Hugging Face logo" src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" width="32px"><br> View on Hugging Face
    </a>
  </td>
</tr></tbody></table>


This Colab notebook demonstrates how to use the HeAR (Health Acoustic Representations) model, loaded directly from Hugging Face, to generate embeddings from health-related audio and use them for pneumonia classification. It builds a data-efficient pneumonia classifier on the COUGHVID dataset, downloaded via KaggleHub.

Embeddings are compact numerical representations of audio data that capture important features, making them suitable for training machine learning models with limited data and computational resources. Learn more about embeddings and their benefits on [this page](https://developers.google.com/health-ai-developer-foundations/hear).

#### Here's a breakdown of the notebook's steps:

1.  **Model Loading:** The HeAR model is loaded from the Hugging Face Hub (requires authentication with your Hugging Face account).

2.  **Dataset Loading:**
    *   **COUGHVID Dataset:** Audio files and labels are downloaded from KaggleHub (orvile/coughvid-v3 dataset).
    *   **Label Extraction:** Pneumonia labels are derived from the CSV file, treating COVID-19 and lower respiratory infections as proxy pneumonia cases.

3.  **Embedding Generation:**
    *   **Preprocessing:** Audio files are loaded with `librosa`, resampled to 16 kHz (the sample rate HeAR expects), and segmented into 2-second clips with 10% overlap; near-silent clips are skipped.
    *   **Inference:** The preprocessed 2-second audio clips are fed to the HeAR model to generate embeddings. Each clip produces a 512-dimensional HeAR embedding vector.
    *   **Visualization (Optional):** The notebook includes functions to display the audio waveform, Mel spectrogram, and an audio player for each file and its individual clips.

4.  **Classifier Training:**
    *   **Train/Test Split:** Clip embeddings are split into training (80%) and test (20%) sets, stratified by pneumonia label.
    *   **Model Selection:** Several scikit-learn classifiers are trained, including:
        *   Support Vector Machine (linear kernel)
        *   Logistic Regression
        *   Gradient Boosting
        *   Random Forest
        *   Multi-layer Perceptron (MLP)
    *   **Training:** Each classifier is trained using the generated HeAR embeddings and pneumonia labels.

5.  **Pneumonia Classification:**
    *   **Evaluation:** Trained classifiers are evaluated on the test set.
    *   **Prediction:** The models predict whether audio clips contain signs of pneumonia.

6.  **Embedding Visualization:**
    *   **PCA Plot:** A 2-D PCA projection of the embeddings, colored by pneumonia status.
    *   **Barcode Visualization:** Each embedding is shown as a "barcode" of the squared deviation of every dimension from the dataset mean.

## Import Required Libraries

In [None]:
import os
import pandas as pd
import numpy as np
import librosa
import matplotlib.pyplot as plt
import librosa.display
import warnings
from IPython.display import Audio, display
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier

# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning, module="soundfile")
warnings.filterwarnings("ignore", module="librosa")

## Define Helper Functions

In [None]:
def plot_waveform(sound, sr, title, figsize=(12, 4), color='blue', alpha=0.7):
    """Plots the waveform of the audio using librosa.display."""
    plt.figure(figsize=figsize)
    librosa.display.waveshow(sound, sr=sr, color=color, alpha=alpha)
    plt.title(f"{title}\nshape={sound.shape}, sr={sr}, dtype={sound.dtype}")
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def plot_spectrogram(sound, sr, title, figsize=(12, 4), n_fft=2048, hop_length=256, n_mels=128, cmap='nipy_spectral'):
    """Plots the Mel spectrogram of the audio using librosa."""
    plt.figure(figsize=figsize)
    mel_spectrogram = librosa.feature.melspectrogram(y=sound, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
    log_mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)
    librosa.display.specshow(log_mel_spectrogram, sr=sr, hop_length=hop_length, x_axis='time', y_axis='mel', cmap=cmap)
    plt.title(f"{title} - Mel Spectrogram")
    plt.tight_layout()
    plt.show()

## Authenticate with HuggingFace

In [None]:
from huggingface_hub.utils import HfFolder

if HfFolder.get_token() is None:
    from huggingface_hub import notebook_login
    notebook_login()

## Setup HeAR Model from Hugging Face

In [None]:
from huggingface_hub import from_pretrained_keras

# Load the model directly from Hugging Face Hub
loaded_model = from_pretrained_keras("google/hear")
# Inference function for embedding generation
infer = loaded_model.signatures["serving_default"]

# HeAR Parameters
SAMPLE_RATE = 16000  # Samples per second (Hz)
CLIP_DURATION = 2    # Duration of the audio clip in seconds
CLIP_LENGTH = SAMPLE_RATE * CLIP_DURATION  # Total number of samples

print(f"HeAR Model loaded successfully!")
print(f"Sample Rate: {SAMPLE_RATE} Hz")
print(f"Clip Duration: {CLIP_DURATION} seconds")
print(f"Clip Length: {CLIP_LENGTH} samples")

## Download COUGHVID Dataset from KaggleHub

In [None]:
%%time
import kagglehub

# Download latest version of the COUGHVID dataset
print("Downloading COUGHVID dataset from KaggleHub...")
dataset_path = kagglehub.dataset_download("orvile/coughvid-v3")

print(f"\nDataset downloaded to: {dataset_path}")
print(f"\nDataset structure:")
for root, dirs, files in os.walk(dataset_path):
    level = root.replace(dataset_path, '').count(os.sep)
    indent = ' ' * 2 * level
    print(f'{indent}{os.path.basename(root)}/')
    subindent = ' ' * 2 * (level + 1)
    for file in files[:5]:  # Show first 5 files per directory
        print(f'{subindent}{file}')
    if len(files) > 5:
        print(f'{subindent}... and {len(files) - 5} more files')

## Load Pneumonia Labels from CSV

In [None]:
# Load the labels CSV file
csv_path = os.path.join(dataset_path, 'tabular_form', 'tabular_form', 'filtered_expert_labels_coughvid_v3.csv')
labels_df = pd.read_csv(csv_path)

print(f"Loaded labels from: {csv_path}")
print(f"Shape: {labels_df.shape}")
print(f"\nColumn names:")
print(labels_df.columns.tolist())
print(f"\nFirst few rows:")
print(labels_df.head())
print(f"\nLabel value counts:")
if 'covid_status' in labels_df.columns:
    print(labels_df['covid_status'].value_counts())
if 'cough_status' in labels_df.columns:
    print(labels_df['cough_status'].value_counts())

## Create Pneumonia Labels Dictionary and Files Map

In [None]:
# Create pneumonia labels dictionary
# Treat 'COVID-19' and 'lower_infection' as pneumonia (True)
# Treat 'healthy' and all other categories as non-pneumonia (False)
pneumonia_labels = {}

for idx, row in labels_df.iterrows():
    # Get the recording ID (filename without the .json extension)
    if 'uuid' in labels_df.columns:
        filename_base = str(row['uuid'])
    elif 'filename' in labels_df.columns:
        filename_base = str(row['filename']).replace('.json', '')
    else:
        # Fall back to the first column as the filename
        filename_base = str(row.iloc[0]).replace('.json', '')

    # Determine pneumonia status based on available columns
    has_pneumonia = False
    if 'covid_status' in labels_df.columns:
        covid_status = str(row['covid_status']).lower()
        if 'covid' in covid_status or 'positive' in covid_status:
            has_pneumonia = True

    # Check for lower respiratory infection only; matching on 'infection'
    # alone would also count other infection types (e.g. upper-tract)
    if 'cough_status' in labels_df.columns:
        cough_status = str(row['cough_status']).lower()
        if 'lower' in cough_status:
            has_pneumonia = True

    pneumonia_labels[filename_base] = has_pneumonia

# Find audio files in the dataset
audio_extensions = ['.wav', '.ogg', '.webm', '.mp3', '.flac']
files_map = {}
audio_dir = os.path.join(dataset_path, 'audio_form', 'audio_form')

if os.path.exists(audio_dir):
    for filename in os.listdir(audio_dir):
        if any(filename.endswith(ext) for ext in audio_extensions):
            file_path = os.path.join(audio_dir, filename)
            # Remove extension to match with labels
            file_base = os.path.splitext(filename)[0]
            # Only add files that have labels
            if file_base in pneumonia_labels:
                files_map[filename] = file_path

print(f"\nPneumonia labels created: {len(pneumonia_labels)} entries")
print(f"Pneumonia cases: {sum(pneumonia_labels.values())}")
print(f"Non-pneumonia cases: {len(pneumonia_labels) - sum(pneumonia_labels.values())}")
print(f"\nAudio files found with labels: {len(files_map)}")
print(f"Audio directory: {audio_dir}")

# Initialize embedding cache
file_embeddings = {}
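
The comments at the top of the previous cell describe the intended labeling rule: a COVID-19 status or a lower-respiratory-infection annotation counts as pneumonia. Below is a minimal sketch of that rule as a standalone, testable function. `is_pneumonia_proxy` is a hypothetical helper, and the label values it is exercised with (`'COVID-19'`, `'lower_infection'`, `'upper_infection'`) are assumptions based on those comments, not a verified schema of the CSV.

```python
def is_pneumonia_proxy(covid_status, cough_status):
    """Hypothetical helper: True if a recording counts as a pneumonia proxy case.

    A COVID-19 status or a lower-respiratory-infection annotation is treated as
    pneumonia; anything else (healthy, upper infection, missing) is not.
    """
    covid = str(covid_status).lower()   # NaN becomes 'nan' and matches nothing
    cough = str(cough_status).lower()
    if "covid" in covid or "positive" in covid:
        return True
    # Require 'lower' so that an upper-tract infection is not counted.
    return "lower" in cough


# Example: one proxy-positive and one proxy-negative recording.
print(is_pneumonia_proxy("COVID-19", "healthy_cough"))
print(is_pneumonia_proxy("healthy", "upper_infection"))
```

Keeping the rule in one small function makes it easy to audit how many recordings change class if the definition of a proxy case is tightened or loosened.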

## Audio Processing Configuration

In [None]:
# Audio display and processing options
SHOW_WAVEFORM = False
SHOW_SPECTROGRAM = False  # Set to True to see spectrograms
SHOW_PLAYER = False
SHOW_CLIPS = False

# Clips of length CLIP_DURATION seconds are extracted from the audio file
# using a sliding window. Adjacent clips are overlapped by CLIP_OVERLAP_PERCENT.
CLIP_OVERLAP_PERCENT = 10

# When True, clips whose RMS loudness is below SILENCE_RMS_THRESHOLD_DB
# are not sent to the HeAR model.
CLIP_IGNORE_SILENT_CLIPS = True
# RMS loudness threshold (dB relative to full scale) below which a clip is treated as silence.
SILENCE_RMS_THRESHOLD_DB = -50

# Limit number of files to process (set to None for all files)
MAX_FILES_TO_PROCESS = 100  # Adjust based on your needs

print(f"Audio processing configuration:")
print(f"  Clip duration: {CLIP_DURATION}s")
print(f"  Clip overlap: {CLIP_OVERLAP_PERCENT}%")
print(f"  Ignore silent clips: {CLIP_IGNORE_SILENT_CLIPS}")
print(f"  Silence threshold: {SILENCE_RMS_THRESHOLD_DB} dB")
print(f"  Max files to process: {MAX_FILES_TO_PROCESS if MAX_FILES_TO_PROCESS else 'All'}")
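
With the defaults above, `CLIP_LENGTH` is 32,000 samples, a 10% overlap is 3,200 samples, and the window advances by 28,800 samples (1.8 s). The sketch below restates the segmentation loop from the next code cell as a standalone function so the arithmetic can be checked without loading HeAR; `segment_audio` is a hypothetical name. One consequence of the clip-count formula is that trailing audio shorter than a full step is dropped: a 5-second recording yields two clips covering 0 to 3.8 s.

```python
import numpy as np

def segment_audio(audio, clip_length=32000, overlap_percent=10):
    """Hypothetical helper: split a 1-D signal into fixed-length, overlapping clips."""
    overlap_samples = int(clip_length * (overlap_percent / 100))
    step_size = clip_length - overlap_samples
    # Same clip-count formula as the notebook; a leftover tail shorter than a step is dropped.
    num_clips = max(1, (len(audio) - overlap_samples) // step_size)
    clips = []
    for i in range(num_clips):
        start = i * step_size
        clip = audio[start:start + clip_length]
        # Zero-pad clips that run past the end of the signal (e.g. very short recordings).
        if len(clip) < clip_length:
            clip = np.pad(clip, (0, clip_length - len(clip)), "constant")
        clips.append(clip)
    return np.stack(clips)


five_seconds = np.zeros(5 * 16000, dtype=np.float32)
print(segment_audio(five_seconds).shape)   # (2, 32000)
```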

## Load Audio Files and Generate HeAR Embeddings

In [None]:
%%time

# Process files
files_to_process = list(files_map.items())
if MAX_FILES_TO_PROCESS:
    files_to_process = files_to_process[:MAX_FILES_TO_PROCESS]

print(f"Processing {len(files_to_process)} audio files...\n")

for idx, (file_key, file_path) in enumerate(files_to_process):
    # Skip files already embedded by a previous run of this cell
    if file_key in file_embeddings:
        continue
    try:
        if (idx + 1) % 10 == 0:
            print(f"Processing file {idx + 1}/{len(files_to_process)}: {file_key}")

        # Load the audio file as a mono numpy array resampled to SAMPLE_RATE
        audio, sample_rate = librosa.load(file_path, sr=SAMPLE_RATE, mono=True)

        # Display audio file (optional)
        if SHOW_WAVEFORM:
            plot_waveform(audio, sample_rate, title=file_key, color='blue')
        if SHOW_SPECTROGRAM:
            plot_spectrogram(audio, sample_rate, file_key, n_fft=2*1024, hop_length=64, n_mels=256, cmap='Blues')
        if SHOW_PLAYER:
            display(Audio(data=audio, rate=sample_rate))

        # Segment audio into overlapping clips
        clip_batch = []
        overlap_samples = int(CLIP_LENGTH * (CLIP_OVERLAP_PERCENT / 100))
        step_size = CLIP_LENGTH - overlap_samples
        num_clips = max(1, (len(audio) - overlap_samples) // step_size)

        for i in range(num_clips):
            start_sample = i * step_size
            end_sample = start_sample + CLIP_LENGTH
            clip = audio[start_sample:end_sample]

            # Pad clip with zeros if less than the required CLIP_LENGTH
            if end_sample > len(audio):
                clip = np.pad(clip, (0, CLIP_LENGTH - len(clip)), 'constant')

            # Compute the clip's RMS loudness in dB relative to full scale
            rms_loudness = round(20 * np.log10(np.sqrt(np.mean(clip**2)) + 1e-10))

            # Skip if clip is too quiet
            if CLIP_IGNORE_SILENT_CLIPS and rms_loudness < SILENCE_RMS_THRESHOLD_DB:
                continue

            # Add clip to batch
            clip_batch.append(clip)

        # Perform HeAR batch inference to get one embedding per clip
        if len(clip_batch) > 0:
            clip_batch = np.asarray(clip_batch)
            if file_key not in file_embeddings:
                embedding_batch = infer(x=clip_batch)['output_0'].numpy()
                file_embeddings[file_key] = embedding_batch

    except Exception as e:
        print(f"Error processing {file_key}: {str(e)}")
        continue

print(f"\nProcessing complete!")
print(f"Successfully processed: {len(file_embeddings)} files")
print(f"Total embeddings generated: {sum(len(emb) for emb in file_embeddings.values())}")
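
The silence filter in the cell above converts each clip's root-mean-square amplitude to decibels relative to full scale (an amplitude of 1.0), i.e. 20·log10(RMS). For reference, a sine wave with peak amplitude 0.5 has RMS 0.5/√2 ≈ 0.354, or about -9 dB, while the -50 dB default corresponds to an RMS of roughly 0.003. The sketch below isolates that computation; `rms_dbfs` and `keep_non_silent` are hypothetical helper names.

```python
import numpy as np

def rms_dbfs(clip, eps=1e-10):
    """RMS loudness in dB relative to full scale, using the same formula as the notebook."""
    return 20 * np.log10(np.sqrt(np.mean(clip ** 2)) + eps)

def keep_non_silent(clips, threshold_db=-50):
    """Keep clips whose rounded loudness is at or above threshold_db (the notebook skips '<')."""
    return [c for c in clips if round(rms_dbfs(c)) >= threshold_db]


sr = 16000
t = np.arange(2 * sr) / sr
tone = 0.5 * np.sin(2 * np.pi * 440 * t)   # RMS about 0.354, roughly -9 dB
silence = np.zeros(2 * sr)                  # RMS 0, floored by eps to -200 dB
print(round(rms_dbfs(tone), 2), round(rms_dbfs(silence), 2))
print(len(keep_non_silent([tone, silence])))
```

The small `eps` only prevents `log10(0)` for all-zero clips; it has no practical effect on audible clips.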

## Prepare Training and Test Sets for Pneumonia Classification

In [None]:
# Collect all embeddings and their corresponding labels
all_embeddings = []
all_labels = []
all_file_names = []

for file_key, embedding_batch in file_embeddings.items():
    # Get the base filename without extension
    file_base = os.path.splitext(file_key)[0]

    # Get the pneumonia label for this file
    if file_base in pneumonia_labels:
        label = 1 if pneumonia_labels[file_base] else 0

        # Add each embedding from this file
        for embedding in embedding_batch:
            all_embeddings.append(embedding)
            all_labels.append(label)
            all_file_names.append(file_key)

# Convert to numpy arrays
all_embeddings = np.array(all_embeddings)
all_labels = np.array(all_labels)

print(f"Total embeddings: {len(all_embeddings)}")
print(f"Total pneumonia cases: {sum(all_labels)}")
print(f"Total non-pneumonia cases: {len(all_labels) - sum(all_labels)}")
print(f"\nEmbedding shape: {all_embeddings.shape}")
print(f"Labels shape: {all_labels.shape}")

# Split into train and test sets with stratification
try:
    train_embeddings, test_embeddings, train_labels, test_labels = train_test_split(
        all_embeddings, all_labels,
        test_size=0.2,
        random_state=42,
        stratify=all_labels
    )
    print(f"\nTrain/test split with stratification successful!")
except ValueError as e:
    print(f"\nStratification not possible: {e}")
    print("Performing split without stratification...")
    train_embeddings, test_embeddings, train_labels, test_labels = train_test_split(
        all_embeddings, all_labels,
        test_size=0.2,
        random_state=42
    )

print(f"\nTraining set:")
print(f"  Total samples: {len(train_embeddings)}")
print(f"  Pneumonia cases: {sum(train_labels)}")
print(f"  Non-pneumonia cases: {len(train_labels) - sum(train_labels)}")

print(f"\nTest set:")
print(f"  Total samples: {len(test_embeddings)}")
print(f"  Pneumonia cases: {sum(test_labels)}")
print(f"  Non-pneumonia cases: {len(test_labels) - sum(test_labels)}")
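
The split above is made at the clip level, so clips cut from the same recording can land in both the training and test sets. Because overlapping clips from one cough recording are likely to be similar, this can make test scores optimistic. One option, sketched below under the assumption that file names identify recordings, is scikit-learn's `GroupShuffleSplit`, which keeps every clip from a file on the same side of the split. `split_by_file` is a hypothetical helper and does not stratify by label; scikit-learn's `StratifiedGroupKFold` (available since version 1.0) combines grouping with stratification.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

def split_by_file(embeddings, labels, file_names, test_size=0.2, random_state=42):
    """Hypothetical helper: train/test split where all clips of a file stay together."""
    file_names = np.asarray(file_names)
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    # test_size is interpreted as a fraction of files (groups), not of clips.
    train_idx, test_idx = next(splitter.split(embeddings, labels, groups=file_names))
    return (embeddings[train_idx], embeddings[test_idx],
            labels[train_idx], labels[test_idx],
            file_names[train_idx], file_names[test_idx])


# Synthetic example: 10 files with 3 clips each, 512-dimensional embeddings.
rng = np.random.default_rng(0)
emb = rng.normal(size=(30, 512))
lab = np.repeat(rng.integers(0, 2, size=10), 3)
names = np.repeat([f"file_{i}.webm" for i in range(10)], 3)
X_tr, X_te, y_tr, y_te, f_tr, f_te = split_by_file(emb, lab, names)
print(len(set(f_tr)), "train files,", len(set(f_te)), "test files")
```

In the notebook this would be called as `split_by_file(all_embeddings, all_labels, all_file_names)`.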

## Train Pneumonia Classifiers

In [None]:
%%time

# Define multiple classifier models
models = {
    "Support Vector Machine (linear)": SVC(kernel='linear', probability=True),
    "Logistic Regression": LogisticRegression(max_iter=1000),
    "Gradient Boosting": GradientBoostingClassifier(n_estimators=128, random_state=42),
    "Random Forest": RandomForestClassifier(n_estimators=128, random_state=42),
    "MLP Classifier": MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=1000, random_state=42),
}

# Train each model
pneumonia_models = {}
print("Training pneumonia classifiers...\n")

for name, model in models.items():
    print(f"Training: {name}")
    model.fit(train_embeddings, train_labels)
    pneumonia_models[name] = model

    # Calculate training accuracy
    train_accuracy = model.score(train_embeddings, train_labels)
    print(f"  Training accuracy: {train_accuracy:.4f}")

print(f"\nAll {len(pneumonia_models)} models trained successfully!")
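
The SVM, logistic regression, and MLP above are fit on raw embedding values. These models can be sensitive to feature scale (tree ensembles such as Random Forest and Gradient Boosting are not), so one variant worth trying is standardizing the embeddings inside a scikit-learn `Pipeline`, which fits the scaler on the training data only. The sketch below uses synthetic data from `make_classification` as a stand-in for 512-dimensional HeAR embeddings; whether scaling improves results on the real embeddings is not established here.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Synthetic stand-in for 512-dimensional embeddings with an 80/20 class imbalance (not real data).
X, y = make_classification(n_samples=200, n_features=512, n_informative=20,
                           weights=[0.8, 0.2], random_state=0)

# The scaler is fitted inside each pipeline, so it only ever sees the data passed to fit().
scaled_models = {
    "Logistic Regression (scaled)": make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000)),
    "SVM linear (scaled)": make_pipeline(StandardScaler(), SVC(kernel="linear")),
}
for name, pipe in scaled_models.items():
    pipe.fit(X, y)
    print(f"{name}: training accuracy {pipe.score(X, y):.3f}")
```

A pipeline exposes the same `fit`, `predict`, and `score` methods as a bare estimator, so it could replace an entry in the `models` dictionary directly.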

## Evaluate Pneumonia Classifiers on Test Set

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

print(f"Evaluating {len(pneumonia_models)} models on test set ({len(test_embeddings)} samples)...\n")
print("="*80)

for model_name, pneumonia_model in pneumonia_models.items():
    # Make predictions
    predictions = pneumonia_model.predict(test_embeddings)

    # Calculate metrics
    test_accuracy = accuracy_score(test_labels, predictions)

    print(f"\n{model_name}")
    print("-" * 80)
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print(f"\nClassification Report:")
    print(classification_report(test_labels, predictions,
                                labels=[0, 1],
                                target_names=['No Pneumonia', 'Pneumonia'],
                                zero_division=0))
    print(f"Confusion Matrix (rows = true, columns = predicted):")
    print(confusion_matrix(test_labels, predictions, labels=[0, 1]))
    print("="*80)
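
The evaluation above scores individual 2-second clips, while a screening decision would usually be made per recording. A minimal sketch of file-level aggregation follows: it averages each file's clip-level pneumonia probabilities and thresholds the mean at 0.5. `aggregate_by_file` is a hypothetical helper, and mean-probability pooling is one choice among several (majority vote or max pooling are alternatives).

```python
import numpy as np

def aggregate_by_file(file_names, clip_probs, threshold=0.5):
    """Hypothetical helper: average clip probabilities per file and threshold the mean."""
    file_names = np.asarray(file_names)
    clip_probs = np.asarray(clip_probs, dtype=float)
    results = {}
    for name in np.unique(file_names):
        mean_prob = clip_probs[file_names == name].mean()
        # (mean probability, predicted label) for each recording
        results[str(name)] = (mean_prob, int(mean_prob >= threshold))
    return results


files = ["a.webm", "a.webm", "b.webm", "b.webm", "b.webm"]
probs = [0.9, 0.7, 0.2, 0.4, 0.3]
results = aggregate_by_file(files, probs)
print(results)
```

To use this in the notebook, the test-set file names must be kept alongside the embeddings, for example by passing `all_file_names` as a third array to `train_test_split`. Clip probabilities can come from `pneumonia_model.predict_proba(test_embeddings)[:, 1]`; all five classifiers here support `predict_proba`, the SVM because it is created with `probability=True`.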

## Visualize Embeddings with PCA (Pneumonia Status)

In [None]:
# Fit PCA on all embeddings
pca = PCA(n_components=2)
all_embeddings_pca = pca.fit_transform(all_embeddings)

# Separate by pneumonia status
pneumonia_mask = all_labels == 1
no_pneumonia_mask = all_labels == 0

# Plot
plt.figure(figsize=(12, 8))

# Plot non-pneumonia cases
plt.scatter(all_embeddings_pca[no_pneumonia_mask, 0],
            all_embeddings_pca[no_pneumonia_mask, 1],
            color='blue', alpha=0.5, label='No Pneumonia', s=50)

# Plot pneumonia cases
plt.scatter(all_embeddings_pca[pneumonia_mask, 0],
            all_embeddings_pca[pneumonia_mask, 1],
            color='red', alpha=0.5, label='Pneumonia', s=50)

plt.xlabel(f"PCA Dimension 1 ({pca.explained_variance_ratio_[0]:.2%} variance)")
plt.ylabel(f"PCA Dimension 2 ({pca.explained_variance_ratio_[1]:.2%} variance)")
plt.title("HeAR Embeddings Visualization - Pneumonia Classification\n(PCA Projection)")
plt.legend(loc='best', fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"Total variance explained by 2 components: {sum(pca.explained_variance_ratio_):.2%}")
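
The axis labels report `explained_variance_ratio_`, the share of total embedding variance captured by each principal component. For centered data with singular values s_i, component k explains s_k² / Σ s_i² of the variance, and the plotted coordinates are the centered data projected onto the top right-singular vectors. The sketch below checks this against scikit-learn on synthetic data (not HeAR embeddings). With 512-dimensional embeddings, two components may capture only a modest fraction of the variance, so overlapping clusters in this plot do not by themselves show that the classifiers cannot separate the classes.

```python
import numpy as np
from sklearn.decomposition import PCA

# Synthetic correlated data standing in for an (n_clips, n_dims) embedding matrix.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 16)) @ rng.normal(size=(16, 16))

Xc = X - X.mean(axis=0)                      # PCA centers each dimension first
_, S, Vt = np.linalg.svd(Xc, full_matrices=False)
ratio_manual = S[:2] ** 2 / np.sum(S ** 2)   # variance share of the top two components
proj_manual = Xc @ Vt[:2].T                  # 2-D coordinates, as plotted above

pca = PCA(n_components=2)
proj_sklearn = pca.fit_transform(X)
print(ratio_manual, pca.explained_variance_ratio_)
```

The sign of each principal component is arbitrary, so the manual and scikit-learn projections can differ by a per-axis sign flip.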

## Visualize Sample Embeddings as Barcodes

In [None]:
# Select a few sample files to visualize (limit to first 5 files)
embedding_mean = np.mean(all_embeddings, axis=0)
sample_files = list(file_embeddings.keys())[:5]

print(f"Visualizing embeddings for {len(sample_files)} sample files...\n")

for file_key in sample_files:
    embedding_batch = file_embeddings[file_key]
    batch_size = embedding_batch.shape[0]

    # Get pneumonia status
    file_base = os.path.splitext(file_key)[0]
    is_pneumonia = pneumonia_labels.get(file_base, False)
    status = "PNEUMONIA" if is_pneumonia else "NO PNEUMONIA"

    # Subtract mean for visualization
    embedding_batch_norm = embedding_batch - embedding_mean

    print(f"{file_key} - {status} ({batch_size} embeddings)")

    plt.figure(figsize=(18, 1 * batch_size))
    for i in range(batch_size):
        embedding_magnitude = embedding_batch_norm[i, :] ** 2
        plt.subplot(batch_size, 1, i + 1)
        plt.imshow(embedding_magnitude.reshape(1, -1), cmap='binary', interpolation=None, aspect='auto')
        plt.title(f"Embedding {i+1}/{batch_size} - {file_key} - {status}", fontsize=10)
        plt.xticks([])
        plt.yticks([])
    plt.tight_layout()
    plt.show()

## Summary and Next Steps

This notebook builds a **pneumonia classifier** on top of HeAR audio embeddings using:

- **Dataset**: COUGHVID-v3 from KaggleHub (orvile/coughvid-v3)
- **Model**: Google HeAR (Health Acoustic Representations)
- **Labels**: Proxy pneumonia labels derived from COVID-19 and lower respiratory infection annotations
- **Classifiers**: SVM, Logistic Regression, Gradient Boosting, Random Forest, MLP

### What the Notebook Does:
- Generates HeAR embeddings for the processed audio files
- Trains multiple classifiers to separate pneumonia from non-pneumonia cases
- Evaluates each classifier on a held-out test set
- Visualizes the embedding space to inspect class separation

### Next Steps:
1. **Tune hyperparameters** for better model performance
2. **Balance dataset** if class imbalance is significant
3. **Add more features** or try ensemble methods
4. **Deploy models** for real-world pneumonia screening
5. **Validate** on external datasets
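
For the "Balance dataset" step, one lightweight option before resampling is class weighting. `LogisticRegression`, `SVC`, and `RandomForestClassifier` accept `class_weight="balanced"`, which weights each class by n_samples / (n_classes × count_of_class); `GradientBoostingClassifier` has no `class_weight` parameter but accepts per-sample weights through `fit(..., sample_weight=...)`. The sketch below computes both kinds of weights for an illustrative 6:2 label split (not the notebook's actual counts).

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight, compute_sample_weight

labels = np.array([0, 0, 0, 0, 0, 0, 1, 1])   # illustrative 6:2 imbalance
classes = np.array([0, 1])
# 'balanced' weight for class c = n_samples / (n_classes * count(c)): 8/(2*6) and 8/(2*2)
class_weights = compute_class_weight(class_weight="balanced", classes=classes, y=labels)
# Per-sample weights with the same rule, e.g. for GradientBoostingClassifier.fit(sample_weight=...)
sample_weights = compute_sample_weight(class_weight="balanced", y=labels)
print(dict(zip(classes.tolist(), class_weights.tolist())))
```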

Explore the other [notebooks](https://github.com/google-health/hear/blob/master/notebooks) to learn what else you can do with the HeAR model.