# Tone Classifier

### Imports

In [2]:
import os
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler  
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import accuracy_score, classification_report

from data_loader import Wav2Vec2FeatureExtractor, Wav2Vec2Dataset, collate_fn_wav2vec2

  from .autonotebook import tqdm as notebook_tqdm


### Datasets
CREMA D + RAVDESS

In [3]:
# CREMA-D data
crema_dir = 'data/CREMA'
crema_files = [f for f in os.listdir(crema_dir) if f.endswith('.wav')]
def extract_crema_label(filename):
    return filename.split('_')[2].replace('.wav', '').lower()

df_crema = pd.DataFrame({
    'file': [os.path.join(crema_dir, f) for f in crema_files],
    'label': [extract_crema_label(f) for f in crema_files]
})

df_crema.head()

Unnamed: 0,file,label
0,data/CREMA\1001_DFA_ANG_XX.wav,ang
1,data/CREMA\1001_DFA_DIS_XX.wav,dis
2,data/CREMA\1001_DFA_FEA_XX.wav,fea
3,data/CREMA\1001_DFA_HAP_XX.wav,hap
4,data/CREMA\1001_DFA_NEU_XX.wav,neu


In [4]:
# RAVDESS data
ravdess_dir = 'data/RAVDESS/'

ravdess_emotion_map = {
    "01": "neutral",
    "02": "calm",
    "03": "happy",
    "04": "sad",
    "05": "angry",
    "06": "fearful",
    "07": "disgust",
    "08": "surprised"
}

ravdess_files = []
for root, _, files in os.walk(ravdess_dir):
    for f in files:
        if f.endswith('.wav'):
            full_path = os.path.join(root, f)
            ravdess_files.append(full_path)


def extract_ravdess_label(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('-')
    emotion_code = parts[2]
    return ravdess_emotion_map.get(emotion_code)

df_ravdess = pd.DataFrame({
    'file': ravdess_files,
    'label': [extract_ravdess_label(f) for f in ravdess_files]
})

# Filter out RAVDESS classes not in CREMA “calm” and “surprised”
df_ravdess = df_ravdess[df_ravdess["label"].isin(["angry","disgust","fearful","happy","neutral","sad"])]

# Drop  rows with unknown emotion code (just in case)
df_ravdess = df_ravdess.dropna()
df_ravdess.head()

Unnamed: 0,file,label
0,data/RAVDESS/Actor_01\03-01-01-01-01-01-01.wav,neutral
1,data/RAVDESS/Actor_01\03-01-01-01-01-02-01.wav,neutral
2,data/RAVDESS/Actor_01\03-01-01-01-02-01-01.wav,neutral
3,data/RAVDESS/Actor_01\03-01-01-01-02-02-01.wav,neutral
12,data/RAVDESS/Actor_01\03-01-03-01-01-01-01.wav,happy


In [5]:
# Combine CREMA and RAVDESS
label_map_crema = {
    'ANG': 'angry',
    'DIS': 'disgust',
    'FEA': 'fearful',
    'HAP': 'happy',
    'NEU': 'neutral',
    'SAD': 'sad'
}
df_crema['label'] = df_crema['label'].map(lambda l: label_map_crema.get(l.upper()))

# Combine
df_combined = pd.concat([df_crema, df_ravdess], ignore_index=True)
df_combined.head()

Unnamed: 0,file,label
0,data/CREMA\1001_DFA_ANG_XX.wav,angry
1,data/CREMA\1001_DFA_DIS_XX.wav,disgust
2,data/CREMA\1001_DFA_FEA_XX.wav,fearful
3,data/CREMA\1001_DFA_HAP_XX.wav,happy
4,data/CREMA\1001_DFA_NEU_XX.wav,neutral


In [6]:
# Build label array
labels = df_combined["label"].values
le = LabelEncoder()
y_all = le.fit_transform(labels)  # numeric

In [None]:
# embeddings for each file
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

feature_extractor = Wav2Vec2FeatureExtractor(model_name="facebook/wav2vec2-base", device=device)

all_embeddings = []
for i, row in df_combined.iterrows():
    audio_path = row["file"]
    emb = feature_extractor.extract_embedding(audio_path)
    all_embeddings.append(emb)

all_embeddings = np.array(all_embeddings)  # shape [num_samples, 768]

print("Embeddings shape:", all_embeddings.shape)

Using device: cuda




Embeddings shape: (9554, 768)


In [8]:
# Train/test split
X_train, X_test, y_train, y_test = train_test_split(
    all_embeddings, y_all, test_size=0.25, stratify=y_all, random_state=42
)

## Tone Classification model

MLP trained on mean-pooled Wav2Vec2 audio embeddings and 5-fold cross-validation

In [None]:
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
input_dim = all_embeddings.shape[1]
output_dim = len(np.unique(y_all))
batch_size = 16
epochs = 50
y_train_array = np.array(y_train)


fold_accuracies = []
best_acc = 0.0
best_model_state = None

# MLP classifier
class ToneClassifierModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.4),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Dropout(0.3),

            nn.Linear(256, output_dim)
        )
    def forward(self, x):
        return self.net(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for fold, (train_idx, val_idx) in enumerate(skf.split(X_train, y_train_array)):
    print(f"\n=== Fold {fold+1}/{n_splits} ===")

    X_fold_train = X_train[train_idx]
    y_fold_train = y_train_array[train_idx]
    X_fold_val = X_train[val_idx]
    y_fold_val = y_train_array[val_idx]
    
    train_dataset = Wav2Vec2Dataset(X_fold_train, y_fold_train)
    val_dataset = Wav2Vec2Dataset(X_fold_val, y_fold_val)

    # WeightedRandomSampler
    class_counts = np.bincount(y_fold_train)
    weights_per_class = [len(y_fold_train)/c for c in class_counts]
    sample_weights = [weights_per_class[label] for label in y_fold_train]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

    # Create DataLoaders with WeightedRandomSampler for train
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,  # No shuffle here
        collate_fn=collate_fn_wav2vec2
    )

    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn_wav2vec2)

    model = ToneClassifierModel(input_dim, output_dim).to(device)

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=7, factor=0.5)

    best_val_loss = float("inf")
    patience = 5
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        train_loss_sum = 0.0
        train_preds, train_labels = [], []

        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item()

            preds = torch.argmax(logits, dim=1)
            train_preds.extend(preds.cpu().numpy())
            train_labels.extend(yb.cpu().numpy())

        train_loss = train_loss_sum / len(train_loader)
        train_acc = accuracy_score(train_labels, train_preds)


        model.eval()
        val_loss_sum = 0.0
        all_preds, all_labels = [], []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                logits = model(xb)
                loss = criterion(logits, yb)
                val_loss_sum += loss.item()
                preds = torch.argmax(logits, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(yb.cpu().numpy())
        val_loss = val_loss_sum / len(val_loader)
        val_acc = accuracy_score(all_labels, all_preds)
        scheduler.step(val_loss)

        print(f"Epoch {epoch+1}/{epochs}, "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")


        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            best_state_dict = model.state_dict()
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Evaluate best state on val set
    model.load_state_dict(best_state_dict)
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(yb.cpu().numpy())

    fold_acc = accuracy_score(all_labels, all_preds)
    fold_accuracies.append(fold_acc)

    if fold_acc > best_acc:
        best_acc = fold_acc
        best_model_state = best_state_dict

    print("Fold Classification Report:\n",
          classification_report(all_labels, all_preds, target_names=le.classes_))

print(f"\n----- Average Accuracy over {n_splits} folds: {np.mean(fold_accuracies):.4f} ± {np.std(fold_accuracies):.4f}")



=== Fold 1/5 ===
Epoch 1/50, Train Loss: 1.4541, Train Acc: 0.4131, Val Loss: 1.3333, Val Acc: 0.4738
Epoch 2/50, Train Loss: 1.3521, Train Acc: 0.4525, Val Loss: 1.2383, Val Acc: 0.5059
Epoch 3/50, Train Loss: 1.2870, Train Acc: 0.4894, Val Loss: 1.2406, Val Acc: 0.5101
Epoch 4/50, Train Loss: 1.2496, Train Acc: 0.5120, Val Loss: 1.1922, Val Acc: 0.5318
Epoch 5/50, Train Loss: 1.2344, Train Acc: 0.5152, Val Loss: 1.1630, Val Acc: 0.5457
Epoch 6/50, Train Loss: 1.1961, Train Acc: 0.5337, Val Loss: 1.1577, Val Acc: 0.5520
Epoch 7/50, Train Loss: 1.1786, Train Acc: 0.5363, Val Loss: 1.1519, Val Acc: 0.5485
Epoch 8/50, Train Loss: 1.1359, Train Acc: 0.5663, Val Loss: 1.1669, Val Acc: 0.5373
Epoch 9/50, Train Loss: 1.1118, Train Acc: 0.5645, Val Loss: 1.1552, Val Acc: 0.5401
Epoch 10/50, Train Loss: 1.1036, Train Acc: 0.5715, Val Loss: 1.1201, Val Acc: 0.5597
Epoch 11/50, Train Loss: 1.0680, Train Acc: 0.5790, Val Loss: 1.1366, Val Acc: 0.5618
Epoch 12/50, Train Loss: 1.0765, Train Acc: 0

### Evaluate on test set

In [None]:
# Evaluate on Test set with best fold model
print("\n----- Evaluate best fold on Test Set -----")
test_dataset = Wav2Vec2Dataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn_wav2vec2)

final_model = ToneClassifierModel(input_dim, output_dim).to(device)
final_model.load_state_dict(best_model_state)

# Save model
torch.save({
    'model_state_dict': final_model.state_dict(),
    'label_encoder': le.classes_,  # Save label classes
}, 'saved_models/tone_classifier.pth')


final_model.eval()

all_test_preds, all_test_labels = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = final_model(xb)
        preds = torch.argmax(logits, dim=1)
        all_test_preds.extend(preds.cpu().numpy())
        all_test_labels.extend(yb.cpu().numpy())

test_acc = accuracy_score(all_test_labels, all_test_preds)
print(f"Test Accuracy: {test_acc:.4f}")
print("Test Classification Report:\n",
      classification_report(all_test_labels, all_test_preds, target_names=le.classes_))


----- Evaluate best fold on Test Set -----
Test Accuracy: 0.5823
Test Classification Report:
               precision    recall  f1-score   support

       angry       0.64      0.80      0.71       414
     disgust       0.55      0.57      0.56       414
     fearful       0.54      0.54      0.54       414
       happy       0.56      0.41      0.48       414
     neutral       0.57      0.63      0.60       319
         sad       0.60      0.55      0.57       414

    accuracy                           0.58      2389
   macro avg       0.58      0.58      0.58      2389
weighted avg       0.58      0.58      0.58      2389

(9554, 768)


### Test the model on our own data:


In [None]:
# Test the model on our own data
def predict_emotion(audio_path, model_path='saved_models/tone_classifier.pth'):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model and label encoder
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    label_classes = checkpoint['label_encoder']
    output_dim = len(label_classes)

    input_dim = 768
    model = ToneClassifierModel(input_dim=input_dim, output_dim=output_dim).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Extract Wav2Vec2 embedding
    extractor = Wav2Vec2FeatureExtractor(model_name="facebook/wav2vec2-base", device=device)
    emb = extractor.extract_embedding(audio_path)  # shape [768]

    # Convert to tensor
    emb_tensor = torch.tensor(emb).unsqueeze(0).to(device)

    # Predict with MLP
    with torch.no_grad():
        logits = model(emb_tensor)
        pred_idx = torch.argmax(logits, dim=1).item()
        predicted_emotion = label_classes[pred_idx]

    return predicted_emotion

In [13]:
test_path = "data/Custom/01_Sahba_ANG.wav"
predicted = predict_emotion(test_path)
print("Predicted Emotion:", predicted)

Predicted Emotion: disgust
