This project focuses on detecting infant cries, screams, and normal utterances in audio data. We use two deep learning architectures, YAMNet (a MobileNet-based audio event classifier pretrained on AudioSet) and Wav2Vec2 (a transformer-based speech model, used here through the Hugging Face Transformers library), and combine their predictions in an ensemble. The end goal is to deploy this system in a Temporal workflow that can handle real-time audio streams.
According to the assignment:
- Select an appropriate number of samples in the experimental (cry, scream) and control (speech, music, other sounds) datasets.
- Preprocess audio data to ensure consistency (e.g., sampling rate, bit depth).
- Train two models:
  - Fine-tuned YAMNet
  - Fine-tuned Wav2Vec2
- Develop an Ensemble of the two models.
- Evaluate models using metrics like accuracy, precision, recall, F1-score, confusion matrices, ROC curves, etc.
- Deploy the final ensemble in a Temporal workflow to manage tasks such as:
  - Receiving audio input
  - Running the classification
  - Storing/managing results
- Deliverables include code, readme, training graphs, and inference instructions.
my_project/
│
├── data_manifest.csv
├── Dockerfile
├── docker-compose.yml
├── requirements.txt
├── README.md
│
├── src/
│ ├── __init__.py
│ ├── preprocess.py
│ ├── training.py
│ ├── training_wave2vec2.py
│ ├── testing_YAMNet.py
│ ├── ensemblemodel.py # Contains the final ensemble inference logic
│
├── models/
│ ├── yamnet_finetuned.h5
│ ├── wav2vec2_finetuned/ # Directory containing PyTorch model checkpoint
│
├── preprocessed_data/
│ └── preprocessed_audio
│
├── temporal_workflow/
│ ├── worker.py # Defines Temporal worker & activities
│ ├── workflow.py # Defines Temporal workflow
│ └── client.py # Triggers workflow executions
│
└── plots/
├── classification_report_yamnet.pdf
├── ensemble_classification_report.pdf
├── ensemble_roc_curves.pdf
├── loss_accuracy_curves_yamnet.pdf
├── loss_accuracy_curves_wav2vec2.pdf
└── …
---
We use several open-source datasets for:
- Infant Cry: AudioSet 1, AudioSet 2
- Screams: AudioSet 3
- Normal Utterances: Common Voice Dataset
- Consistent Sampling Rate: We resample all audio to 16kHz.
- Segmentation: Each audio file is segmented into short clips (5-15 seconds).
- Labeling: Segments are labeled as `crying`, `screaming`, or `normal`.
- Metadata: A CSV (`data_manifest.csv`) keeps track of `file_path` and `label`.
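The preprocessing steps above can be sketched with NumPy and the standard `csv` module. This is a minimal illustration and not the code in `src/preprocess.py`. The following are all assumptions made for the example:
- the linear-interpolation resampler
- the 10-second default clip length
- the helper names `resample_linear`, `segment` and `write_manifest`

```python
import csv

import numpy as np

TARGET_SR = 16_000  # every clip is resampled to 16 kHz


def resample_linear(wave: np.ndarray, orig_sr: int, target_sr: int = TARGET_SR) -> np.ndarray:
    """Naive linear-interpolation resampler (illustrative assumption).

    It applies no anti-aliasing filter, so downsampling can alias.
    """
    wave = np.asarray(wave, dtype=np.float32)
    if orig_sr == target_sr:
        return wave
    n_out = int(round(len(wave) * target_sr / orig_sr))
    t_in = np.arange(len(wave)) / orig_sr   # timestamps of the input samples
    t_out = np.arange(n_out) / target_sr    # timestamps of the output samples
    return np.interp(t_out, t_in, wave).astype(np.float32)


def segment(wave: np.ndarray, sr: int = TARGET_SR,
            clip_s: float = 10.0, min_s: float = 5.0) -> list[np.ndarray]:
    """Cut a waveform into consecutive clip_s-second clips.

    A trailing remainder is kept only if it is at least min_s seconds long,
    so with 5 <= clip_s <= 15 every clip stays within the 5-15 s range.
    """
    clip_len, min_len = int(clip_s * sr), int(min_s * sr)
    clips = []
    for start in range(0, len(wave), clip_len):
        clip = wave[start:start + clip_len]
        if len(clip) >= min_len:
            clips.append(clip)
    return clips


def write_manifest(rows: list[dict], fh) -> None:
    """Write (file_path, label) rows in the data_manifest.csv layout."""
    writer = csv.DictWriter(fh, fieldnames=["file_path", "label"])
    writer.writeheader()
    writer.writerows(rows)
```

For real recordings, use a band-limited resampler such as `librosa.resample` or `scipy.signal.resample_poly`. Plain interpolation does not remove content above the new Nyquist frequency.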
- We aimed for a balanced dataset with enough examples of each class (cry, scream, normal).
- Experimental: cry and scream.
- Control: normal speech, music, ambient indoor sounds.
- To ensure each class is well-represented, we used stratified sampling to select training/validation/test splits. This helps maintain class distribution in each split.
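One simple way to obtain equal class counts is random undersampling to the size of the smallest class. The README does not say how balance was achieved, so treat this as a hedged sketch. The function name and the manifest-style row dicts (with a `label` key) are assumptions.

```python
import random
from collections import defaultdict


def balance_by_undersampling(rows: list[dict], seed: int = 0) -> list[dict]:
    """Randomly keep the same number of rows for every label.

    Each class is reduced to the size of the smallest class.
    """
    by_label = defaultdict(list)
    for row in rows:
        by_label[row["label"]].append(row)
    n_min = min(len(items) for items in by_label.values())
    rng = random.Random(seed)  # fixed seed keeps the selection reproducible
    balanced = []
    for label in sorted(by_label):
        balanced.extend(rng.sample(by_label[label], n_min))
    return balanced
```

Undersampling discards data. When one class is much smaller, common alternatives are to oversample it or to weight the loss by inverse class frequency.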
- Base Model: YAMNet, pretrained on AudioSet.
- Fine-Tuning: We replace the classification layer to output 3 classes and freeze/unfreeze select layers.
- Implementation: In `src/training.py` (or a separate script), we load YAMNet from TF Hub, adapt the final layers, and train on our labeled segments.
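A common way to adapt YAMNet from TF Hub is to use it as a frozen feature extractor and train a small Keras head on its 1024-dimensional per-frame embeddings. The sketch below assumes that approach, which is not necessarily what `src/training.py` does. The head's layer sizes, the mean pooling over frames and the helper names are all illustrative choices.

```python
import tensorflow as tf

YAMNET_HANDLE = "https://tfhub.dev/google/yamnet/1"
CLASS_NAMES = ["crying", "screaming", "normal"]
EMBEDDING_DIM = 1024  # size of each per-frame YAMNet embedding


def load_yamnet():
    """Load the pretrained YAMNet SavedModel from TF Hub (downloads on first use)."""
    import tensorflow_hub as hub  # imported lazily so the head can be built without it
    return hub.load(YAMNET_HANDLE)


def clip_embedding(yamnet, waveform: tf.Tensor) -> tf.Tensor:
    """Average YAMNet's per-frame embeddings into one vector per clip.

    `waveform` must be a 1-D float32 tensor of 16 kHz mono samples in [-1, 1].
    """
    _scores, embeddings, _spectrogram = yamnet(waveform)  # embeddings: (frames, 1024)
    return tf.reduce_mean(embeddings, axis=0)


def build_head(num_classes: int = len(CLASS_NAMES)) -> tf.keras.Model:
    """Trainable 3-way classifier on top of the frozen YAMNet embeddings."""
    head = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(EMBEDDING_DIM,)),
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
    head.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",  # integer labels 0..2
        metrics=["accuracy"],
    )
    return head
```

In this setup, embeddings for the training clips are computed once with `clip_embedding` and passed to `head.fit(...)`. Because YAMNet itself stays frozen, only the head's weights are learned.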
- Base Model: `Wav2Vec2ForSequenceClassification`.
- Fine-Tuning: We load a pretrained checkpoint and adapt the final classification head for 3 classes.
- Implementation: In `src/training_wave2vec2.py`, we use the Hugging Face `Trainer` or a custom training loop on the same labeled segments.
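The Wav2Vec2 setup described above can be sketched with the Transformers API. The `facebook/wav2vec2-base` checkpoint, the frozen convolutional feature encoder and the hand-written training step are assumptions for illustration, not details taken from `src/training_wave2vec2.py`.

```python
import torch
from transformers import Wav2Vec2ForSequenceClassification

LABELS = ["crying", "screaming", "normal"]
ID2LABEL = {i: label for i, label in enumerate(LABELS)}
LABEL2ID = {label: i for i, label in enumerate(LABELS)}


def load_wav2vec2_classifier(checkpoint: str = "facebook/wav2vec2-base"):
    """Pretrained Wav2Vec2 encoder with a new, randomly initialised 3-class head."""
    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=len(LABELS),
        id2label=ID2LABEL,
        label2id=LABEL2ID,
    )
    # Keep the convolutional feature encoder frozen; fine-tune the transformer and head.
    model.freeze_feature_encoder()
    return model


def training_step(model, optimizer, waveforms: torch.Tensor, labels: torch.Tensor) -> float:
    """One gradient step on a batch of 16 kHz waveforms shaped (batch, samples)."""
    model.train()
    # Passing labels makes the model return a cross-entropy loss over the 3 classes.
    outputs = model(input_values=waveforms, labels=labels)
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return outputs.loss.item()
```

In practice, waveforms are usually normalised with `Wav2Vec2FeatureExtractor` before they reach the model. The Hugging Face `Trainer` can replace `training_step` entirely.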
Note: Each training run saves its own best model checkpoint:
- `models/yamnet_finetuned.h5`
- `models/wav2vec2_finetuned/` (directory with PyTorch checkpoints)
The assignment recommends:
- Averaging Probabilities
- Majority Voting
- Meta-classifier (e.g., training a small network on top of the outputs)
We chose probability averaging as it was straightforward and yielded strong performance:
- Obtain softmax probabilities from YAMNet.
- Obtain softmax probabilities from Wav2Vec2.
- Average them element-wise:
  $$P_\text{combined} = \frac{P_\text{YAMNet} + P_\text{Wav2Vec2}}{2}$$
- Classify using `argmax(P_combined)`.
Implementation can be found in `src/ensemblemodel.py`.
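Here is a minimal NumPy sketch of the averaging rule. It assumes both models emit one probability vector per clip, in the same class order. The names are illustrative and are not taken from `src/ensemblemodel.py`.

```python
import numpy as np

CLASS_NAMES = ["crying", "screaming", "normal"]  # both models must use this index order


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax, for models that return logits."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def ensemble_predict(p_yamnet: np.ndarray, p_wav2vec2: np.ndarray):
    """Average two probability arrays element-wise and return (P_combined, argmax)."""
    if p_yamnet.shape != p_wav2vec2.shape:
        raise ValueError("both models must output probabilities over the same classes")
    p_combined = (p_yamnet + p_wav2vec2) / 2.0
    return p_combined, p_combined.argmax(axis=-1)
```

For example, averaging `[0.5, 0.3, 0.2]` (YAMNet) with `[0.1, 0.7, 0.2]` (Wav2Vec2) gives `[0.3, 0.5, 0.2]`, so the clip is labeled `screaming`. Wav2Vec2 logits would first pass through `softmax`. NumPy's `argmax` resolves exact ties in favour of the lowest class index.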
- 70% for training
- 15% for validation
- 15% for testing
- Ensured stratified distribution across cry, scream, normal classes.
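The stratified 70/15/15 split can be reproduced with a per-class shuffle-and-slice. The plain-Python sketch below assumes manifest rows are dicts with a `label` key; the function name and seed are illustrative. A common library alternative is to call `sklearn.model_selection.train_test_split(..., stratify=labels)` twice.

```python
import random
from collections import defaultdict


def stratified_split(rows: list[dict], fractions=(0.70, 0.15, 0.15), seed: int = 42):
    """Split rows into train/val/test, applying the same fractions within every label."""
    by_label = defaultdict(list)
    for row in rows:
        by_label[row["label"]].append(row)
    rng = random.Random(seed)
    train, val, test = [], [], []
    for label in sorted(by_label):
        items = by_label[label][:]
        rng.shuffle(items)
        n_train = round(len(items) * fractions[0])
        n_val = round(len(items) * fractions[1])
        train += items[:n_train]
        val += items[n_train:n_train + n_val]
        test += items[n_train + n_val:]  # remainder goes to the test split
    return train, val, test
```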
- We monitor validation loss and accuracy.
- Use early stopping to avoid overfitting.
- Final models are tested on the unseen 15% test set.
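Early stopping on validation loss amounts to a small patience counter. The class below is a framework-agnostic sketch; its `patience` and `min_delta` defaults are assumptions. Built-in equivalents also exist:
- Keras: `tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=..., restore_best_weights=True)`
- Hugging Face `Trainer`: `EarlyStoppingCallback`, which also requires `load_best_model_at_end=True`

```python
class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch: int, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss  # new best: this is the checkpoint to keep
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```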
- We compute accuracy, precision, recall, F1-score for each class.
- Additionally, we generate confusion matrices and ROC curves.
- See `plots/ensemble_classification_report.pdf` for a summary of the ensemble classification metrics and confusion matrix.
- For the individual models, see `plots/classification_report_yamnet.pdf`.
- The ensemble ROC curves are in `plots/ensemble_roc_curves.pdf`. Each class (cry, scream, normal) is plotted separately.
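Per-class precision, recall and F1 all derive from the confusion matrix. The NumPy sketch below shows the arithmetic and assumes class indices 0, 1 and 2 mean cry, scream and normal. In practice, `sklearn.metrics.classification_report` and `sklearn.metrics.confusion_matrix` produce the same numbers. One-vs-rest `sklearn.metrics.roc_curve` calls give the per-class ROC data.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes: int = 3) -> np.ndarray:
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm


def per_class_metrics(cm: np.ndarray):
    """Return (precision, recall, f1, accuracy); empty denominators give 0."""
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)  # column sums: how often each class was predicted
    actual = cm.sum(axis=1)     # row sums: how often each class truly occurs
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    accuracy = tp.sum() / cm.sum()
    return precision, recall, f1, accuracy
```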
- Wav2Vec2 model accuracy: 91%
- YAMNet model accuracy: 70%
- Ensemble model accuracy: 97%
Overall, the ensemble outperforms both individual models, with higher recall for the cry and scream classes and balanced precision for normal utterances.
For multi-class classification (cry, scream, normal), we use categorical cross-entropy (in Keras) or cross-entropy loss (in PyTorch). This is suitable because:
- Probabilistic Outputs: We want each model to output probabilities over the 3 classes.
- One-hot Targets: Each audio segment belongs to exactly one class.
- Experimental vs. Control: The data has two “experimental” classes (cry, scream) and one “control” class (normal speech/music/other). Cross-entropy is standard for such tasks.
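With one-hot targets, categorical cross-entropy reduces to the negative log-probability of the true class, averaged over the batch. The NumPy sketch below computes it directly. The clipping constant `eps` is an added assumption that avoids `log(0)`.

```python
import numpy as np


def categorical_cross_entropy(probs: np.ndarray, one_hot: np.ndarray, eps: float = 1e-12) -> float:
    """L = -(1/N) * sum_i sum_k y_ik * log(p_ik) for N samples and K classes."""
    probs = np.clip(probs, eps, 1.0)  # guard against log(0)
    return float(-(one_hot * np.log(probs)).sum(axis=1).mean())
```

For a single clip predicted as `[0.7, 0.2, 0.1]` whose true class is cry, the loss is `-log(0.7)`, about 0.357. The loss grows without bound as the probability assigned to the true class approaches zero.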
The assignment requires designing a Temporal workflow to:
- Receive and preprocess audio input (5–15 seconds each).
- Run the ensemble model for classification.
- Store and manage results (e.g., logs, metrics).
Workflow components:
- `workflow.py`: Defines the workflow structure (which activities to call, and in what order).
- `worker.py`: Defines the worker that executes the activities (preprocessing, inference).
- `client.py`: Submits workflow executions.
- The assignment states that the workflow should handle real-time audio streams of at least 5 and up to 15 seconds. In practice, you can:
  - Stream or chunk audio to the workflow.
  - Have each chunk trigger an activity that runs the ensemble model.
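Chunking a live stream can be sketched as a buffer that emits fixed-length windows. Each incoming block is assumed to be a 16 kHz float array. Two further rules are assumptions chosen to respect the 5 to 15 second limit:
- chunks are 10 seconds long
- a final remainder shorter than 5 seconds is dropped

```python
from typing import Iterable, Iterator

import numpy as np

SR = 16_000  # samples per second


def chunk_stream(blocks: Iterable[np.ndarray], sr: int = SR,
                 chunk_s: float = 10.0, min_s: float = 5.0) -> Iterator[np.ndarray]:
    """Accumulate incoming sample blocks and yield chunk_s-second chunks.

    A final partial chunk is yielded only if it is at least min_s seconds long.
    """
    chunk_len, min_len = int(chunk_s * sr), int(min_s * sr)
    buffer = np.zeros(0, dtype=np.float32)
    for block in blocks:
        buffer = np.concatenate([buffer, np.asarray(block, dtype=np.float32)])
        while len(buffer) >= chunk_len:
            yield buffer[:chunk_len]
            buffer = buffer[chunk_len:]
    if len(buffer) >= min_len:
        yield buffer
```

Each emitted chunk would then be persisted and passed to its own workflow execution, so one slow or failed chunk does not hold up the rest of the stream.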