# Multi-Output CNN for Speech Recognition and Gender Classification

This notebook implements a **Multi-Output CNN** that predicts both the spoken digit and the speaker's gender simultaneously.

## 🧠 Methodology
1. **Feature Extraction**
   - Extract Mel-Frequency Cepstral Coefficients (MFCCs) from audio recordings using Librosa
   - Standardize input shape by padding/truncating sequences to a fixed length

2. **Data Preparation**
   - Convert audio into NumPy arrays of shape (n_mfcc, time, 1)
   - Encode labels:
     - Gender → one-hot vector ([1,0] = male, [0,1] = female)
     - Digit → one-hot vector of length 10

3. **Model Architecture (Multi-Output CNN)**
   - Convolutional Neural Network (CNN) extracts time–frequency patterns
   - Shared convolutional layers → two output branches:
     - Gender classifier (softmax over 2 classes)
     - Digit classifier (softmax over 10 classes)

4. **Training & Evaluation**
   - Train/test split = 70/30
   - Loss function = categorical crossentropy (for both tasks)
   - Metrics = accuracy, precision, recall, F1-score
   - Visualization of training history (accuracy/loss curves)

## Dataset Structure
```
dataset/
   d0/ (digit 0)
      male/   → male speakers saying "zero"
      female/ → female speakers saying "zero"
   d1/ (digit 1)
      male/   → male speakers saying "one"
      female/ → female speakers saying "one"
   ...
   d9/ (digit 9)
      male/   → male speakers saying "nine"
      female/ → female speakers saying "nine"
```

**Note**: We predict BOTH digit (0-9) AND gender (male/female) simultaneously using a Multi-Output CNN.


### 1. Imports and Setup

In [None]:
import os
import sys
import warnings
warnings.filterwarnings('ignore')

# Make the sibling ``src`` directory (one level above the working directory)
# importable so the project modules can be imported by bare name.
src_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'src'))
already_importable = src_path in sys.path
if not already_importable:
    sys.path.append(src_path)

# Third-party imports
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import librosa
import librosa.display
from sklearn.metrics import classification_report, confusion_matrix

# Project imports
from prepare_dataset import load_dataset
from train_model import build_model
from extract_features import extract_mfcc
from utils import plot_history

# Shared look for every figure drawn in this notebook.
plt.style.use('seaborn-v0_8')
sns.set_palette('husl')

# Seed NumPy and TensorFlow with the same fixed value so reruns are repeatable.
for seed_everything in (np.random.seed, tf.random.set_seed):
    seed_everything(42)

print('✅ Imports complete. Ready to run the pipeline.')

### 2. Configuration

In [None]:
# Run-wide settings used by the cells below.
DATASET_PATH = 'dataset/'  # must match src expectations
EPOCHS, BATCH_SIZE = 20, 32

# Figures/reports go to plots/, the trained network to models/; create both
# up front (no error if they already exist).
for output_dir in ('plots', 'models'):
    os.makedirs(output_dir, exist_ok=True)

print('DATASET_PATH:', DATASET_PATH)
print('EPOCHS:', EPOCHS, 'BATCH_SIZE:', BATCH_SIZE)

### 3. Dataset Analysis


In [None]:
# Count the .wav recordings per digit folder and per gender folder, then show
# the distribution as a pie chart, a bar chart and a small text summary.
digits = [f'd{i}' for i in range(10)]
genders = ['male', 'female']


def _wav_names(folder):
    """Return the names of the .wav files directly inside *folder*.

    A missing folder (or a path that is not a directory) yields an empty list.
    """
    if not os.path.isdir(folder):
        return []
    return [name for name in os.listdir(folder) if name.lower().endswith('.wav')]


digit_counts = dict.fromkeys(digits, 0)
gender_counts = dict.fromkeys(genders, 0)

for digit_name in digits:
    for gender_name in genders:
        n_wavs = len(_wav_names(os.path.join(DATASET_PATH, digit_name, gender_name)))
        digit_counts[digit_name] += n_wavs
        gender_counts[gender_name] += n_wavs

# Every file is counted exactly once per digit, so the digit totals add up to
# the overall total.
total_files = sum(digit_counts.values())

print('Total audio files:', total_files)
print('Gender counts:', gender_counts)
print('Digit counts:', digit_counts)

# Visualization
if total_files == 0:
    print('No files found. Please ensure dataset is available at', DATASET_PATH)
else:
    male_total = gender_counts.get('male', 0)
    female_total = gender_counts.get('female', 0)

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    pie_ax, bar_ax, summary_ax = axes

    # Gender pie
    pie_ax.pie([male_total, female_total],
               labels=['Male', 'Female'], autopct='%1.1f%%',
               colors=['lightblue', 'lightpink'], startangle=90)
    pie_ax.set_title('Gender Distribution')

    # Digit bar, with the sample count written just above each bar
    bar_names = list(digit_counts)
    bar_heights = list(digit_counts.values())
    drawn_bars = bar_ax.bar(bar_names, bar_heights, alpha=0.8, color='lightgreen')
    bar_ax.set_title('Digit Distribution')
    bar_ax.set_ylabel('Samples')
    bar_ax.tick_params(axis='x', rotation=45)
    for rect, n_samples in zip(drawn_bars, bar_heights):
        x_center = rect.get_x() + rect.get_width()/2
        bar_ax.text(x_center, rect.get_height() + 1, str(n_samples),
                    ha='center', va='bottom')

    # Summary panel: plain text lines placed in axes-relative coordinates
    summary_ax.axis('off')
    summary_lines = (
        (0.7, f'Total Files: {total_files}', {'fontsize': 14, 'fontweight': 'bold'}),
        (0.5, f'Male: {male_total}', {'fontsize': 12}),
        (0.35, f'Female: {female_total}', {'fontsize': 12}),
    )
    for y_pos, line_text, font_kwargs in summary_lines:
        summary_ax.text(0.5, y_pos, line_text, ha='center', va='center',
                        transform=summary_ax.transAxes, **font_kwargs)

    plt.tight_layout()
    plt.show()

### 4. MFCC Feature Extraction Demo

In [None]:
# Pick one recording from the dataset and show its waveform next to the MFCC
# matrix returned by the project's extract_mfcc().
#
# The first digit folder (d0..d9) that contains a .wav wins; within a digit,
# male/ is checked before female/. The directory listing is sorted because
# os.listdir() returns entries in arbitrary order, which would otherwise make
# the demo sample differ from one machine/run to the next.
sample_file = None
for d in [f'd{i}' for i in range(10)]:
    for gender_name in ('male', 'female'):
        gdir = os.path.join(DATASET_PATH, d, gender_name)
        if not os.path.isdir(gdir):
            continue
        wavs = sorted(f for f in os.listdir(gdir) if f.lower().endswith('.wav'))
        if wavs:
            sample_file = os.path.join(gdir, wavs[0])
            break
    if sample_file:
        break

if sample_file is None:
    print('❌ No sample audio found in dataset. Skipping demo.')
else:
    print('Sample file:', sample_file)
    mfcc = extract_mfcc(sample_file)
    # A None result from extract_mfcc is treated as an extraction failure.
    if mfcc is None:
        print('Failed to extract MFCC for sample.')
    else:
        print('MFCC shape:', mfcc.shape)
        # Load raw audio for waveform (librosa resamples to the requested 16 kHz)
        y, sr = librosa.load(sample_file, sr=16000)
        fig, axes = plt.subplots(1, 2, figsize=(14, 4))
        # Waveform
        axes[0].plot(y)
        axes[0].set_title('Waveform')
        axes[0].set_xlabel('Samples')
        axes[0].set_ylabel('Amplitude')
        # MFCC heatmap; origin='lower' puts row 0 of the matrix at the bottom
        img = axes[1].imshow(mfcc, aspect='auto', origin='lower', cmap='viridis')
        axes[1].set_title('MFCC (normalized)')
        axes[1].set_xlabel('Time Frames')
        axes[1].set_ylabel('MFCC Coefficients')
        fig.colorbar(img, ax=axes[1], shrink=0.8)
        plt.tight_layout()
        plt.show()

### 5. Load Dataset and Build Model

In [None]:
# Fetch the train/test arrays from the project loader (the 70/30 split is
# handled inside load_dataset). Each y_* holds one target array per output,
# printed below in (digit, gender) order.
X_train, X_test, y_train, y_test = load_dataset(DATASET_PATH)

print('X_train:', X_train.shape, 'X_test:', X_test.shape)
for split_name, split_targets in (('y_train', y_train), ('y_test', y_test)):
    print(f'{split_name} (digit, gender):', [target.shape for target in split_targets])

# Build model: the network input is a single sample, i.e. the feature shape
# without the leading sample axis.
input_shape = X_train.shape[1:]
model = build_model(input_shape)
model.summary()

### 6. Train Model


In [None]:
# Fit the two-output model, using the test split as validation data, then
# plot and save the accuracy and loss curves under plots/.
fit_options = {
    'validation_data': (X_test, y_test),
    'epochs': EPOCHS,
    'batch_size': BATCH_SIZE,
    'verbose': 1,
}
history = model.fit(X_train, y_train, **fit_options)

# Show training curves
plot_history(history)

# Save training curves
h = history.history


def _draw_available_curves(curve_specs):
    """Plot each (history key, legend label, extra plot kwargs) whose key was recorded."""
    for metric_key, legend_label, plot_kwargs in curve_specs:
        if metric_key in h:
            plt.plot(h[metric_key], label=legend_label, **plot_kwargs)


def _decorate_save_show(title, y_label, file_name):
    """Title/label the current figure, write it to plots/<file_name> and display it."""
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join('plots', file_name))
    plt.show()


# Accuracy plots: solid lines for the digit output, dashed for the gender output
plt.figure(figsize=(8,5))
_draw_available_curves((
    ('digit_output_accuracy', 'train_digit', {}),
    ('val_digit_output_accuracy', 'val_digit', {}),
    ('gender_output_accuracy', 'train_gender', {'linestyle': '--'}),
    ('val_gender_output_accuracy', 'val_gender', {'linestyle': '--'}),
))
_decorate_save_show('Accuracy', 'Accuracy', 'training_accuracy.png')

# Loss plot
plt.figure(figsize=(8,5))
_draw_available_curves((
    ('loss', 'train_loss', {}),
    ('val_loss', 'val_loss', {}),
))
_decorate_save_show('Loss', 'Loss', 'training_loss.png')

### 7. Evaluation

In [None]:
# Evaluate both outputs on the test split: print and save the classification
# reports, then draw and save one confusion matrix per task under plots/.


def _plot_confusion_matrix(cm, title, figsize, file_name):
    """Draw *cm* as an annotated heatmap, save it to plots/<file_name>, show it.

    cm        -- square confusion matrix (rows = true class, cols = predicted)
    title     -- figure title
    figsize   -- (width, height) passed to plt.figure
    file_name -- file name inside the plots/ directory
    """
    plt.figure(figsize=figsize)
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()
    ticks = np.arange(cm.shape[0])
    plt.xticks(ticks, ticks)
    plt.yticks(ticks, ticks)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    # Cells darker than half the maximum get white text so counts stay legible.
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     horizontalalignment='center',
                     color='white' if cm[i, j] > thresh else 'black')
    plt.tight_layout()
    plt.savefig(os.path.join('plots', file_name))
    plt.show()


# Predictions (model has two outputs). A tuple is accepted as well as a list so
# the check does not depend on which container predict() happens to return.
preds = model.predict(X_test)
if isinstance(preds, (list, tuple)) and len(preds) == 2:
    y_pred_digit_proba, y_pred_gender_proba = preds
else:
    raise RuntimeError('Model did not return two outputs on predict()')

# Collapse one-hot targets and per-class scores into integer class ids.
y_true_digit = np.argmax(y_test[0], axis=1)
y_true_gender = np.argmax(y_test[1], axis=1)
y_pred_digit = np.argmax(y_pred_digit_proba, axis=1)
y_pred_gender = np.argmax(y_pred_gender_proba, axis=1)

# Classification reports
digit_report = classification_report(y_true_digit, y_pred_digit)
gender_report = classification_report(y_true_gender, y_pred_gender)
print('Digit Classification Report:\n', digit_report)
print('Gender Classification Report:\n', gender_report)

# Save reports (real newlines, so each header sits on its own line)
with open(os.path.join('plots', 'classification_reports.txt'), 'w', encoding='utf-8') as f:
    f.write('Digit Classification Report:\n')
    f.write(digit_report + '\n\n')
    f.write('Gender Classification Report:\n')
    f.write(gender_report + '\n')

# Passing the full label set keeps each matrix at (n_classes x n_classes) even
# when some class never occurs in the test split, so tick i always means class i.
digit_labels = np.arange(y_test[0].shape[1])
gender_labels = np.arange(y_test[1].shape[1])

# Confusion matrix - digits
cm_digits = confusion_matrix(y_true_digit, y_pred_digit, labels=digit_labels)
_plot_confusion_matrix(cm_digits, 'Confusion Matrix - Digits', (8,6), 'confusion_digits.png')

# Confusion matrix - gender
cm_gender = confusion_matrix(y_true_gender, y_pred_gender, labels=gender_labels)
_plot_confusion_matrix(cm_gender, 'Confusion Matrix - Gender', (4,4), 'confusion_gender.png')

### 8. Save Model

In [None]:
# Persist the trained network under models/ and tell the user where it went.
model_dir, model_file = 'models', 'gender_digit_classifier.h5'
model_path = os.path.join(model_dir, model_file)
model.save(model_path)

print('✅ Model saved at:', model_path)
print('You can now use src/interface.py or src/gender_classifier.py for inference.')