# Training Advanced Notebook

Advanced pipeline for **Mental State Detection from Speech** using:
- Dual-channel spectrogram input (Mel + IMel)
- Data augmentation (hooks provided)
- CNN + GRU hybrid model
- Keras training with callbacks, model saving, and evaluation

Notes:
- Adjust `DATA_PATH` to point to `data/processed` or to your own folder of audio files.
- This notebook expects `.wav` files organized under `DATA_PATH/<label>/*.wav`.

---

## 1) Setup & Imports

In [None]:
import os, random, math, warnings
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from glob import glob
import librosa
import librosa.display

import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, utils
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# Silence library warnings and apply seaborn's default plot styling.
warnings.filterwarnings('ignore')
sns.set()

# Seed every RNG the notebook relies on. Python's `random` drives the
# augmentation choices, NumPy drives random crops and noise, and TensorFlow
# drives weight initialisation, dropout and batch shuffling during `fit`.
# Without the TensorFlow seed, two runs of the notebook train different models.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

## 2) Parameters & Paths — adjust to your environment

In [None]:
# --- Locations -------------------------------------------------------------
DATA_PATH = "data/processed"   # change if needed
MODEL_DIR = "experiments/models"
os.makedirs(MODEL_DIR, exist_ok=True)  # created up front so later saves cannot fail on a missing dir

# --- Audio framing ---------------------------------------------------------
SR = 16000        # sample rate every clip is loaded at
DURATION = 4.0    # seconds: trim/pad to this
SAMPLES = int(SR * DURATION)  # fixed clip length in samples

# --- Spectrogram settings --------------------------------------------------
N_MELS = 64
WIN_LENGTH = 1024  # used as the default `n_fft` of the spectrogram helpers
HOP_LENGTH = 512

# --- Training --------------------------------------------------------------
BATCH_SIZE = 32
EPOCHS = 40


## 3) Helpers: load audio, trim/pad, augmentation hooks

In [None]:
def load_audio_fixed(path, sr=SR, samples=SAMPLES):
    """Load an audio file and force it to exactly `samples` samples.

    Args:
        path: Audio file to read with `librosa.load`.
        sr: Sample rate passed to `librosa.load`.
        samples: Required output length in samples.

    Returns:
        1-D array of length `samples`. Longer clips are cut to a window
        whose start is drawn from NumPy's global RNG; shorter clips are
        zero-padded at the end.
    """
    signal, _ = librosa.load(path, sr=sr)
    excess = len(signal) - samples
    if excess > 0:
        # Random crop: any start in [0, excess] leaves a full window.
        offset = np.random.randint(0, excess + 1)
        return signal[offset:offset + samples]
    # Exact length or too short: append `-excess` (>= 0) zeros.
    return np.pad(signal, (0, -excess), mode='constant')

def augment_audio_hook(y, sr=SR):
    """Apply one randomly chosen augmentation to a waveform.

    One of 'none', 'pitch', 'stretch', 'noise' or 'shift' is picked with
    Python's `random` module. The result always has the same number of
    samples as the input, so fixed-length clips stay fixed-length.

    Args:
        y: 1-D waveform array.
        sr: Sample rate of `y`; used for pitch shifting and to convert the
            shift range (+/- 0.1 s) into samples.

    Returns:
        The augmented waveform, or `y` unchanged when 'none' is drawn, the
        input is empty, or time-stretching fails.
    """
    n_samples = len(y)
    if n_samples == 0:
        # Nothing to augment; `np.max` below would raise on an empty array.
        return y

    fn = random.choice(['none', 'pitch', 'stretch', 'noise', 'shift'])
    if fn == 'pitch':
        # `sr` and `n_steps` are keyword-only in librosa >= 0.10; keywords
        # are accepted by older releases as well.
        return librosa.effects.pitch_shift(y, sr=sr, n_steps=random.uniform(-2, 2))
    if fn == 'stretch':
        rate = random.uniform(0.9, 1.1)
        try:
            # `rate` is keyword-only in librosa >= 0.10. Passed positionally
            # it raised TypeError, which the handler below swallowed, so
            # this branch silently never stretched anything.
            stretched = librosa.effects.time_stretch(y, rate=rate)
        except Exception:
            # Deliberate best-effort: an augmentation failure must not
            # abort dataset construction, so fall back to the clean signal.
            return y
        # Stretching changes the duration; crop or zero-pad back to the
        # original length so downstream spectrograms keep a uniform width.
        if len(stretched) >= n_samples:
            return stretched[:n_samples]
        return np.pad(stretched, (0, n_samples - len(stretched)), mode='constant')
    if fn == 'noise':
        # Gaussian noise scaled to at most 0.5% of the peak amplitude.
        noise_amp = 0.005 * np.random.uniform() * np.max(np.abs(y))
        return y + noise_amp * np.random.normal(size=y.shape)
    if fn == 'shift':
        # Circular shift by up to +/- 0.1 s (samples wrap around the ends).
        shift = int(random.uniform(-0.1, 0.1) * sr)
        return np.roll(y, shift)
    return y

## 4) Spectrogram functions — Mel + IMel (IMel = Mel filterbank reversed to emphasize high freq)

We compute standard Mel-spectrogram and an IMel spectrogram by applying a reversed Mel filterbank to the power spectrogram. Both are log-compressed and resized to the same shape for model input.

In [None]:
def make_mel_spectrogram(y, sr=SR, n_mels=N_MELS, n_fft=WIN_LENGTH, hop_length=HOP_LENGTH):
    """Compute a log-scaled (dB) Mel spectrogram of a waveform.

    Args:
        y: Waveform array.
        sr: Sample rate used to build the Mel filterbank.
        n_mels: Number of Mel bands (rows of the result).
        n_fft: FFT size for the STFT and the filterbank.
        hop_length: Hop between successive STFT frames, in samples.

    Returns:
        `librosa.power_to_db` of the Mel-filtered power spectrogram, with
        one row per Mel band.
    """
    # Power spectrogram: squared magnitude of the STFT.
    power = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length)) ** 2
    filterbank = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
    # Project the linear-frequency bins onto the Mel bands, then compress to dB.
    return librosa.power_to_db(filterbank.dot(power))

def make_imel_spectrogram(y, sr=SR, n_mels=N_MELS, n_fft=WIN_LENGTH, hop_length=HOP_LENGTH):
    """Compute a log-scaled (dB) inverted-Mel (IMel) spectrogram.

    The IMel filterbank is the Mel filterbank mirrored along the frequency
    axis, so its narrow, densely spaced filters sit at the high-frequency
    end instead of the low-frequency end.

    Args:
        y: Waveform array.
        sr: Sample rate used to build the Mel filterbank.
        n_mels: Number of filters (rows of the result).
        n_fft: FFT size for the STFT and the filterbank.
        hop_length: Hop between successive STFT frames, in samples.

    Returns:
        `librosa.power_to_db` of the IMel-filtered power spectrogram, with
        the same shape as the Mel spectrogram and rows ordered from low to
        high frequency.
    """
    S = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
    S_power = np.abs(S) ** 2
    mel_fb = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
    # Mirror each filter along the frequency-bin axis (axis 1) to move the
    # fine resolution to high frequencies. Reversing only the rows, as before,
    # merely reorders the same Mel filters: the result was the Mel
    # spectrogram upside-down and carried no new information.
    # The row flip (axis 0) then restores low-to-high ordering of the bands.
    imel_fb = mel_fb[::-1, ::-1]
    imel_spec = np.dot(imel_fb, S_power)
    return librosa.power_to_db(imel_spec)

def normalize_spectrogram(spec):
    """Min-max scale a single spectrogram into the range [0, 1].

    Args:
        spec: Non-empty array of spectrogram values.

    Returns:
        A new array whose minimum is 0 and, unless the input is constant,
        whose maximum is 1. A constant input maps to all zeros. The input
        array is not modified.
    """
    shifted = spec - spec.min()
    peak = shifted.max()
    # A zero peak means the input was constant; skip the division.
    return shifted / peak if peak > 0 else shifted


## 5) Build dataset (spectrogram pairs) — careful with memory; use a generator for big datasets
Below we build arrays directly for simplicity. For large datasets implement `tf.data.Dataset` or a Keras generator.

In [None]:
def build_spectrogram_dataset(data_path=DATA_PATH, augment=False, max_files_per_label=None):
    """Build the in-memory dataset of two-channel (Mel, IMel) spectrograms.

    Every sub-directory of `data_path` is treated as a class label and every
    `*.wav` file inside it as one example.

    Args:
        data_path: Root folder laid out as `<data_path>/<label>/*.wav`.
        augment: When True, each clip is passed through
            `augment_audio_hook` with probability 0.5.
        max_files_per_label: Optional cap on the number of files used per
            label (the first files in sorted order); falsy means no cap.

    Returns:
        Tuple `(X, y_enc, le)`: `X` is a float32 array of stacked
        `(n_mels, frames, 2)` examples, `y_enc` the integer-encoded labels
        and `le` the fitted `LabelEncoder`.

    Raises:
        ValueError: If no `.wav` file is found under `data_path`.
    """
    X = []  # will hold (n_mels, frames, 2) arrays
    y = []
    labels = sorted(
        d for d in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, d))
    )
    for label in labels:
        # glob returns files in arbitrary order; sort so the capped subset
        # (and the RNG consumption order) is the same on every run.
        files = sorted(glob(os.path.join(data_path, label, "*.wav")))
        if max_files_per_label:
            files = files[:max_files_per_label]
        for f in files:
            y_audio = load_audio_fixed(f)
            if augment and random.random() < 0.5:
                n_samples = len(y_audio)
                y_audio = augment_audio_hook(y_audio)
                # Augmentations such as time-stretch can change the clip
                # length. Restore it so every example has the same number
                # of frames and the examples can be stacked into one array.
                if len(y_audio) > n_samples:
                    y_audio = y_audio[:n_samples]
                elif len(y_audio) < n_samples:
                    y_audio = np.pad(y_audio, (0, n_samples - len(y_audio)), mode='constant')
            mel = make_mel_spectrogram(y_audio)
            imel = make_imel_spectrogram(y_audio)
            # ensure both channels share the same frames dimension
            min_frames = min(mel.shape[1], imel.shape[1])
            mel = normalize_spectrogram(mel[:, :min_frames])
            imel = normalize_spectrogram(imel[:, :min_frames])
            pair = np.stack([mel, imel], axis=-1)  # shape: (n_mels, frames, 2)
            X.append(pair.astype('float32'))
            y.append(label)
    if not X:
        raise ValueError(f"No .wav files found under {data_path!r}; expected <label>/*.wav")
    # np.stack fails loudly on a shape mismatch instead of silently
    # producing an object array.
    X = np.stack(X)
    le = LabelEncoder()
    y_enc = le.fit_transform(y)
    return X, y_enc, le

# Build a capped subset as a quick example; drop `max_files_per_label` (or raise it) for real training.
# NOTE: augment=True is applied here, before the train/test split further down,
# so clips that end up in the test set may be augmented as well.
X, y, label_encoder = build_spectrogram_dataset(
    data_path=DATA_PATH,
    augment=True,
    max_files_per_label=200,
)
print('X shape:', X.shape, 'y shape:', y.shape)


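## 6) Prepare inputs for Keras: pad/crop frames to a fixed width and split train/test
Keras expects `(batch, height, width, channels)`. The arrays built above are already channels-last, with `height = n_mels`, `width = frames`, `channels = 2`, so no transpose is needed; we only make the frame axis a consistent width.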

In [None]:
# Optionally resize to fixed width (frames) for consistent shapes.
def pad_frames(X, target_width=None):
    """Force every example to the same number of frames (axis 1).

    Args:
        X: Iterable of `(height, width, channels)` arrays.
        target_width: Desired width. When None, the median width of the
            examples in `X` (truncated to int) is used.

    Returns:
        Tuple `(array, target_width)`: the examples stacked into one array,
        each zero-padded on the right or cropped to `target_width`, and the
        width that was applied.
    """
    if target_width is None:
        # choose median width
        target_width = int(np.median([sample.shape[1] for sample in X]))

    fitted = []
    for sample in X:
        missing = target_width - sample.shape[1]
        if missing > 0:
            # Too narrow: append `missing` all-zero frames.
            sample = np.pad(sample, ((0, 0), (0, missing), (0, 0)), mode='constant')
        else:
            # Wide enough: keep the first `target_width` frames.
            sample = sample[:, :target_width, :]
        fitted.append(sample)
    return np.array(fitted), target_width

# Make every example the same width (median frame count by default).
X, target_width = pad_frames(X)
print('After padding, X shape:', X.shape)

# Stratified 80/20 hold-out split, then one-hot targets for categorical cross-entropy.
num_classes = np.unique(y).size
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=y,
)
y_train_cat = utils.to_categorical(y_train, num_classes)
y_test_cat = utils.to_categorical(y_test, num_classes)


## 7) Model: Dual-channel CNN -> GRU -> Dense
We build a small, well-regularized model that takes `(n_mels, frames, 2)` input.

In [None]:
def build_cnn_gru(input_shape, num_classes):
    """Build and compile the CNN -> bidirectional GRU classifier.

    Two Conv2D/BatchNorm/MaxPool stages extract local time-frequency
    features; the feature map is then rearranged into a sequence over the
    frame axis and summarised by a bidirectional GRU.

    Args:
        input_shape: Per-example shape `(n_mels, frames, channels)`; both
            spatial sizes must be known (not None) for the reshape.
        num_classes: Number of output classes (size of the softmax layer).

    Returns:
        A compiled Keras model (Adam, categorical cross-entropy, accuracy).
    """
    inp = layers.Input(shape=input_shape)
    x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(inp)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPool2D((2, 2))(x)
    x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPool2D((2, 2))(x)
    # Current shape: (batch, n_mels // 4, frames // 4, 64).
    # Read the static shape from the tensor itself: `tf.keras.backend.int_shape`
    # is not available in the Keras 3 backend namespace, while `.shape`
    # works on both Keras 2 and Keras 3.
    n_freq, n_time, n_channels = x.shape[1], x.shape[2], x.shape[3]
    # Put time first, then fold frequency and channels into one feature
    # vector per time step: (batch, time, freq * channels).
    x = layers.Permute((2, 1, 3))(x)
    x = layers.Reshape((n_time, n_freq * n_channels))(x)
    x = layers.Bidirectional(layers.GRU(128, return_sequences=False))(x)
    x = layers.Dropout(0.4)(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    out = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs=inp, outputs=out)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# Per-example shape: everything after the batch axis of the training array.
input_shape = X_train.shape[1:]
model = build_cnn_gru(input_shape=input_shape, num_classes=num_classes)
model.summary()


## 8) Callbacks & Training
We save the best model and use early stopping.

In [None]:
checkpoint_path = os.path.join(MODEL_DIR, 'best_model.h5')

# Two callbacks: keep the model with the best validation accuracy on disk,
# and stop once validation loss has not improved for 7 epochs.
cb = [
    callbacks.ModelCheckpoint(
        checkpoint_path,
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1,
    ),
    callbacks.EarlyStopping(
        monitor='val_loss',
        patience=7,
        restore_best_weights=True,
        verbose=1,
    ),
]

# 10% of the training split is held out by Keras for validation.
history = model.fit(
    x=X_train,
    y=y_train_cat,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_split=0.1,
    callbacks=cb,
    verbose=2,
)


## 9) Evaluation: test set metrics and confusion matrix

In [None]:
# Evaluate with the weights of the best checkpoint written during training.
model.load_weights(checkpoint_path)
y_prob = model.predict(X_test)
y_pred = np.argmax(y_prob, axis=1)

# Pass the full list of class ids explicitly. Without `labels=`, a class that
# is missing from both y_test and y_pred is dropped: classification_report
# then raises because `target_names` no longer matches, and the confusion
# matrix becomes smaller than the tick labels drawn on it.
class_names = label_encoder.classes_
class_ids = np.arange(len(class_names))
print(classification_report(y_test, y_pred, labels=class_ids, target_names=class_names))
cm = confusion_matrix(y_test, y_pred, labels=class_ids)

plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()


## 10) Save model and label encoder
The checkpoint callback already saved the best model during training; here we additionally save the full model and the label encoder.

In [None]:
# Persist the model and the label encoder side by side so that predictions
# can later be mapped back to class names.
full_model_path = os.path.join(MODEL_DIR, 'full_model.h5')
model.save(full_model_path)

import joblib

encoder_path = os.path.join(MODEL_DIR, 'label_encoder.pkl')
joblib.dump(label_encoder, encoder_path)
print('Saved full_model.h5 and label_encoder.pkl')


## Next steps / Tips
- For large datasets convert `build_spectrogram_dataset` into a streaming `tf.data.Dataset` pipeline.  
- Try data balancing, class weights, or focal loss for imbalanced classes.  
- Experiment with different spectrogram sizes, delta channels (deltas/delta-deltas), or multimodal (text) fusion.  
- Use cross-validation and keep experiments reproducible (seed RNGs, log hyperparams).