# Import Libraries

In [1]:
# Imports
import os
import librosa
import numpy as np
import pandas as pd
import soundfile as sf
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import classification_report, accuracy_score
from catboost import CatBoostClassifier

import joblib
from datetime import datetime


# DEFINE CONSTANTS

In [None]:
# Define the emotion map
emotions = {
    '01': 'neutral',
    '02': 'calm',
    '03': 'happy',
    '04': 'sad',
    '05': 'angry',
    '06': 'fearful',
    '07': 'disgust',
    '08': 'surprised'
}


# Directory where audio clips are stored
AUDIO_DIR = "../data/model_training"

# EXTRACT FEATURES

In [3]:
def extract_feature(file_name, mfcc=True, chroma=True, mel=True):
    with sf.SoundFile(file_name) as sound_file:
        X = sound_file.read(dtype="float32")
        sample_rate = sound_file.samplerate
        result = np.array([])

        if chroma:
            stft = np.abs(librosa.stft(X))

        if mfcc:
            mfccs = np.mean(librosa.feature.mfcc(y=X, sr=sample_rate, n_mfcc=40).T, axis=0)
            result = np.hstack((result, mfccs))

        if chroma:
            chroma_vals = np.mean(librosa.feature.chroma_stft(S=stft, sr=sample_rate).T, axis=0)
            result = np.hstack((result, chroma_vals))

        if mel:
            mel_vals = np.mean(librosa.feature.melspectrogram(y=X, sr=sample_rate).T, axis=0)
            result = np.hstack((result, mel_vals))

        return result

# Store features and labels
features = []
labels = []

# Iterate through audio files
for file in os.listdir(AUDIO_DIR):
    if file.endswith(".wav"):
        try:
            # Extract emotion code from filename
            emotion_code = file.split("-")[2]
            emotion_label = emotions.get(emotion_code)

            if emotion_label is not None:
                file_path = os.path.join(AUDIO_DIR, file)
                feat = extract_feature(file_path, mfcc=True, chroma=True, mel=True)
                features.append(feat)
                labels.append(emotion_label)
            else:
                print(f"Unknown emotion code in file: {file}")

        except Exception as e:
            print(f"Error processing {file}: {e}")

# Convert to DataFrame
X = np.array(features)
y = np.array(labels)

print(f"Extracted features from {len(X)} audio files.")

Extracted features from 600 audio files.


# DATA UNDERSTANDING AND MANIPULATION

In [4]:
df = pd.DataFrame(X)
df["label"] = y
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,label
0,-693.497009,50.064392,0.571451,14.329966,3.336371,-2.540720,-4.057909,-10.711999,-7.294139,1.740189,-4.190643,1.954662,-5.247894,2.781430,-3.167567,-3.400083,-2.378032,-0.568718,-6.477533,-1.243206,-2.805425,-5.436358,-0.446875,-3.635166,-2.983727,-0.563903,-1.651019,-0.555945,-3.410184,-2.244655,-3.130589,-2.700900,-1.888213,-0.554154,-3.964598,-2.134852,-3.945771,-1.624579,-2.039907,-3.629108,...,0.000169,0.000256,0.000241,0.000429,0.000298,0.000427,0.000106,0.000059,0.000023,0.000009,0.000020,0.000011,0.000037,0.000037,0.000053,0.000038,0.000089,0.000078,0.000042,0.000065,0.000035,0.000024,0.000025,0.000029,0.000040,0.000027,0.000020,0.000015,0.000011,0.000008,0.000006,0.000004,0.000004,0.000004,0.000004,3.484883e-06,4.075517e-06,1.816080e-06,7.833277e-07,neutral
1,-635.504028,46.859524,-9.306540,7.782464,-9.412089,-1.989408,-12.161742,-4.883361,-3.250412,-2.833500,-5.138503,-2.569628,-3.077204,-5.598739,-2.394116,-2.471290,-2.046306,-0.376148,-0.547966,-0.399688,-1.076877,1.802583,-2.664821,-0.438062,-0.839944,0.895285,1.415430,0.622499,2.487288,3.137058,3.639866,3.297008,0.703012,-0.611102,0.449688,-0.371180,-0.491689,-0.907375,-0.620054,-3.137493,...,0.001620,0.000779,0.000674,0.000575,0.000615,0.000241,0.000149,0.000248,0.000193,0.000144,0.000236,0.000270,0.000099,0.000194,0.000404,0.000212,0.000308,0.000238,0.000153,0.000270,0.000130,0.000120,0.000120,0.000041,0.000043,0.000043,0.000038,0.000052,0.000050,0.000049,0.000035,0.000037,0.000041,0.000016,0.000008,4.410210e-06,8.532292e-06,9.460701e-06,5.014022e-06,neutral
2,-605.083862,51.924683,-3.406904,12.412755,-0.915094,-8.159618,-13.884240,-2.129108,-10.921885,-1.894085,-4.600322,-3.519313,-3.302905,-4.813470,-2.889497,-5.215399,-1.513640,-5.002609,-4.783944,-5.074886,-4.014137,-3.637026,-3.621915,-0.814910,-3.035481,-1.623347,-2.690664,-0.920039,-2.230682,-0.674591,0.845057,0.357810,-2.443825,-2.693513,-2.416559,1.518210,3.940207,3.606819,0.360174,-0.592614,...,0.000094,0.000151,0.000183,0.000171,0.000263,0.000126,0.000116,0.000542,0.000291,0.000099,0.000063,0.000080,0.000168,0.000128,0.000079,0.000048,0.000027,0.000038,0.000124,0.000132,0.000101,0.000160,0.000165,0.000217,0.000140,0.000219,0.000097,0.000054,0.000021,0.000021,0.000031,0.000058,0.000041,0.000056,0.000041,3.777348e-05,2.298606e-05,1.448703e-05,1.382678e-05,neutral
3,-677.658447,34.967129,-7.350281,8.442965,-9.494029,-3.458288,-14.567134,-9.458730,-5.309513,-2.019068,-8.971309,-2.284987,-1.614616,-6.364736,-2.237740,-0.266151,-1.323639,4.257756,0.021959,-0.048609,-5.269603,-2.586573,-0.803066,-0.117805,2.492227,2.970306,2.970241,0.762497,-0.372282,-1.293416,1.329679,2.962503,2.018912,0.556141,2.322719,5.812874,3.535504,0.329168,-0.234255,0.468158,...,0.000161,0.000340,0.000792,0.000891,0.000121,0.000053,0.000102,0.000106,0.000067,0.000033,0.000039,0.000162,0.000115,0.000093,0.000086,0.000110,0.000147,0.000218,0.000104,0.000080,0.000145,0.000092,0.000065,0.000044,0.000064,0.000049,0.000047,0.000095,0.000063,0.000025,0.000030,0.000025,0.000013,0.000010,0.000015,2.938529e-05,3.755880e-05,1.395798e-05,4.256830e-06,neutral
4,-693.824036,65.587341,5.120113,19.882288,3.768996,1.375945,-3.604618,-0.851656,0.442239,3.579777,0.357326,1.456788,-0.570600,3.312793,0.468876,0.242004,2.378075,0.264953,-0.480116,1.217201,-1.635459,0.521646,-0.927023,-0.608878,0.587435,1.456136,0.151707,0.070004,0.653547,0.879826,0.415578,1.166848,2.200671,-1.105553,-2.105721,-0.705193,0.681609,0.519609,-1.166736,-2.206154,...,0.000019,0.000025,0.000020,0.000026,0.000025,0.000031,0.000030,0.000023,0.000021,0.000011,0.000020,0.000015,0.000021,0.000017,0.000028,0.000027,0.000017,0.000020,0.000026,0.000024,0.000039,0.000017,0.000016,0.000016,0.000012,0.000009,0.000005,0.000004,0.000003,0.000002,0.000002,0.000003,0.000002,0.000002,0.000001,5.695222e-07,4.089248e-07,2.952644e-07,2.321282e-07,neutral
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
595,-589.108093,26.858681,-17.479113,4.390011,-12.381732,-6.286273,-22.038509,-8.076139,-4.327149,-8.230455,-7.499243,-4.583221,-6.453560,-9.817992,-9.355742,-6.717556,-6.704798,-5.899377,0.136241,2.806711,3.623627,7.093021,1.325453,0.499462,0.945769,2.876611,4.413840,0.470237,-0.285627,-0.521613,1.119759,-0.002113,-0.832818,-2.199784,-2.987162,-1.334808,-1.463171,0.061110,-1.220239,-2.217192,...,0.000538,0.000339,0.000321,0.000326,0.000323,0.000530,0.000593,0.000762,0.000989,0.001015,0.000962,0.001247,0.001286,0.001065,0.000670,0.000613,0.000463,0.000464,0.000533,0.000664,0.000478,0.000492,0.000352,0.000335,0.000350,0.000337,0.000405,0.000490,0.000329,0.000252,0.000340,0.000283,0.000210,0.000289,0.000282,2.559582e-04,1.447562e-04,7.761814e-05,7.385520e-05,surprised
596,-504.896698,39.952988,-3.060997,11.396744,-3.274184,-8.512946,-7.421699,-17.272579,-3.605226,-5.265520,-11.849092,-1.704592,-6.171687,0.008878,-5.852609,-5.810474,-1.524358,-3.779039,-4.449309,-1.440191,-4.397798,-2.543599,-1.616699,-1.089385,2.180590,1.056068,2.358412,-0.200973,-1.469126,0.428541,-0.429052,0.851503,-0.923545,-1.476289,-2.311039,0.742897,1.359477,-0.435120,-1.774554,-0.315426,...,0.022975,0.013876,0.015787,0.022120,0.015437,0.015789,0.011217,0.009579,0.012257,0.008438,0.004896,0.004510,0.001936,0.002194,0.003014,0.005012,0.004207,0.005998,0.006960,0.011521,0.012942,0.017066,0.011431,0.008303,0.004835,0.006010,0.005692,0.005847,0.007064,0.006917,0.009070,0.013681,0.008999,0.006867,0.003101,2.449021e-03,1.741198e-03,1.290890e-03,8.814579e-04,surprised
597,-528.612549,19.169508,-16.147158,-5.208413,-5.308363,-14.133931,-13.485912,-14.254867,-3.872610,-4.847421,-12.681711,-5.726952,-11.084762,-2.806641,-6.955564,3.561604,0.398212,-1.363866,0.561937,1.316976,1.761419,8.442794,3.990061,2.899348,-0.018976,0.609733,1.072685,5.026315,0.330953,0.583089,-0.176946,1.757131,2.663590,0.543257,-2.503008,1.268932,0.661094,2.103149,0.608056,0.513274,...,0.001896,0.000742,0.001231,0.001097,0.001733,0.002437,0.003279,0.002674,0.001707,0.001329,0.001536,0.001525,0.001084,0.001262,0.001617,0.001588,0.001087,0.000791,0.001343,0.001182,0.001727,0.002015,0.002203,0.002196,0.002890,0.004526,0.004490,0.003982,0.002608,0.002621,0.002181,0.002499,0.002847,0.002867,0.002940,2.533817e-03,1.232030e-03,5.744396e-04,4.125543e-04,surprised
598,-516.632385,29.370678,-14.151953,-0.688881,-4.882926,-8.434866,-11.605946,-9.844908,-8.373046,-1.798220,-7.112112,-4.787591,-7.224380,-4.206560,-5.506617,-6.537671,-3.681231,-0.316746,-2.150414,3.389361,3.209548,1.569520,2.883451,2.107571,-0.208351,2.712918,1.022809,2.078006,2.035244,0.543170,1.289818,1.393380,1.960183,1.158321,0.658743,0.466846,1.593172,2.239456,1.361393,1.174186,...,0.006238,0.010900,0.009355,0.007609,0.005153,0.004770,0.003928,0.002410,0.000993,0.000777,0.000900,0.001250,0.001421,0.002654,0.002240,0.003124,0.003420,0.002982,0.003211,0.003112,0.002999,0.003684,0.003306,0.002661,0.001708,0.001612,0.001838,0.001860,0.001357,0.001332,0.001972,0.002605,0.002026,0.001264,0.001495,1.616502e-03,1.979766e-03,1.413110e-03,3.962120e-04,surprised


In [5]:
df["label"].value_counts()

label
calm         80
happy        80
sad          80
angry        80
disgust      80
fearful      80
surprised    80
neutral      40
Name: count, dtype: int64

In [None]:
# Save processed audio to CSV
process_path = "../data/processed/audio_features"
filename = f"processed_audio_features-{datetime.now().strftime('%b_%d-%Hh_%Mm')}.csv"
df.to_csv(process_path+filename, index=False)
print(f"Audio features saved to processed folder with name: {filename}")

Audio features saved to processed folder with name: processed_audio_features-Apr_28-09h_39m.csv


# MODEL TRAINING AND EVALUATION

In [7]:
# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train MLP Classifier
mlp_clf = MLPClassifier(alpha=0.001,
                    hidden_layer_sizes=(512, 256),
                    max_iter=1000,
                    learning_rate="adaptive",
                    activation="logistic",
                    solver="adam",
                    learning_rate_init=0.001,
                    random_state=42)
mlp_clf.fit(X_train, y_train)

# Evaluate
y_pred = mlp_clf.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
print("Classification Report:\n", classification_report(y_test, y_pred))


Accuracy: 0.7583333333333333
Classification Report:
               precision    recall  f1-score   support

       angry       0.79      0.85      0.81        13
        calm       0.79      0.95      0.86        20
     disgust       1.00      0.67      0.80         9
     fearful       0.68      0.83      0.75        18
       happy       0.62      0.53      0.57        15
     neutral       0.83      0.62      0.71         8
         sad       0.80      0.53      0.64        15
   surprised       0.76      0.86      0.81        22

    accuracy                           0.76       120
   macro avg       0.78      0.73      0.75       120
weighted avg       0.77      0.76      0.75       120



In [31]:
# Train Catboost Classifier
cat_clf = CatBoostClassifier(
    iterations=350,
    learning_rate=0.1,
    depth=5,
    l2_leaf_reg=0.0001,
    random_state=42,
    verbose=50)
cat_clf.fit(X_train, y_train)

# Evaluate
y_pred = cat_clf.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
print("Classification Report:\n", classification_report(y_test, y_pred))

0:	learn: 1.9845602	total: 92.6ms	remaining: 32.3s
50:	learn: 0.6662426	total: 4.82s	remaining: 28.3s
100:	learn: 0.2937049	total: 10.2s	remaining: 25.1s
150:	learn: 0.1361913	total: 14.7s	remaining: 19.4s
200:	learn: 0.0645394	total: 19.2s	remaining: 14.2s
250:	learn: 0.0322816	total: 24.1s	remaining: 9.5s
300:	learn: 0.0161419	total: 28.6s	remaining: 4.66s
349:	learn: 0.0079050	total: 32.6s	remaining: 0us
Accuracy: 0.7666666666666667
Classification Report:
               precision    recall  f1-score   support

       angry       0.73      0.85      0.79        13
        calm       0.77      1.00      0.87        20
     disgust       0.89      0.89      0.89         9
     fearful       0.76      0.72      0.74        18
       happy       0.75      0.60      0.67        15
     neutral       0.83      0.62      0.71         8
         sad       0.67      0.53      0.59        15
   surprised       0.78      0.82      0.80        22

    accuracy                           0.77     

In [33]:
# Save the model for later use
model = "cat"

MODEL_PATH = "../models/"
os.makedirs("../models", exist_ok=True)
filename = "emotion_model"

if model == "mlp":
    filename += f"-mlp_clf-{datetime.now().strftime("%b_%d-%Hh_%Mm")}.pkl"
    joblib.dump(mlp_clf, MODEL_PATH+filename)
elif model == "cat":
    filename += f"-cat_clf-{datetime.now().strftime("%b_%d-%Hh_%Mm")}.pkl"
    joblib.dump(cat_clf, MODEL_PATH+filename)

print(f"Model saved to {MODEL_PATH+filename}")


Model saved to ../models/emotion_model-cat_clf-Apr_28-10h_03m.pkl
