In [1]:
import math

import librosa
import numpy as np

def get_mfcc(file_path):
    y, sr = librosa.load(file_path)  # read .wav file
    hop_length = math.floor(sr * 0.010)  # 10ms hop
    win_length = math.floor(sr * 0.025)  # 25ms frame
    # mfcc is 12 x T matrix
    mfcc = librosa.feature.mfcc(
        y=y, sr=sr, n_mfcc=12, n_fft=1024,
        hop_length=hop_length, win_length=win_length)
    # subtract mean from mfcc --> normalize mfcc
    mfcc = mfcc - np.mean(mfcc, axis=1).reshape((-1, 1))
    # delta feature 1st order and 2nd order
    delta1 = librosa.feature.delta(mfcc, order=1)
    delta2 = librosa.feature.delta(mfcc, order=2)
    # X is 36 x T
    X = np.concatenate([mfcc, delta1, delta2], axis=0) 
    # return T x 36 (transpose of X)
    return X.T  # hmmlearn use T x N matrix

In [2]:
import os
import pickle

import hmmlearn.hmm as hmm
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

In [3]:
class_names = ['XemNhietDo','XemGio','XemNgay','XemThoiTiet','BatNhac','TatNhac','BatQuat','TatQuat','BatDen','TatDen']
states = [10, 7, 8, 11, 7, 7, 7, 7, 7,7]
dataset_path = 'datasets/train_audio'

In [4]:
X = {'train': {}, 'test': {}}
y = {'train': {}, 'test': {}}

model = {}
model_path = 'model'

In [5]:
length = 0
for cname in class_names:
    length += len(os.listdir(f"{dataset_path}/{cname}"))
print('Total samples:', length)

all_data = {}
all_labels = {}
for cname in class_names:
    file_paths = [os.path.join(dataset_path, cname, i) for i in os.listdir(
        os.path.join(dataset_path, cname)) if i.endswith('.wav')]
    data = [get_mfcc(file_path) for file_path in file_paths]
    all_data[cname] = data
    all_labels[cname] = [class_names.index(cname) for _ in range(len(file_paths))]

for cname in class_names:
    x_train, x_test, y_train, y_test = train_test_split(
        all_data[cname], all_labels[cname],
        test_size=0.33,
        random_state=42
    )

    X['train'][cname] = x_train
    X['test'][cname] = x_test
    y['test'][cname] = y_test

Total samples: 540


In [6]:
total_train = 0
total_test = 0
for cname in class_names:
    train_count = len(X['train'][cname])
    test_count = len(X['test'][cname])
    print(cname, 'train:', train_count, '| test:', test_count)
    total_train += train_count
    total_test += test_count
print('train samples:', total_train)
print('test samples', total_test)

xemnhietdo train: 40 | test: 20
xemgio train: 40 | test: 20
xemngay train: 40 | test: 20
xemthoitiet train: 40 | test: 20
tat train: 40 | test: 20
bat train: 40 | test: 20
quat train: 40 | test: 20
nhac train: 40 | test: 20
den train: 40 | test: 20
train samples: 360
test samples 180


In [8]:
for idx, cname in enumerate(class_names):
    start_prob = np.full(states[idx], 0.0)
    start_prob[0] = 1.0
    trans_matrix = np.full((states[idx], states[idx]), 0.0)
    p = 0.5
    np.fill_diagonal(trans_matrix, p)
    np.fill_diagonal(trans_matrix[0:, 1:], 1 - p)
    trans_matrix[-1, -1] = 1.0
    # trans matrix
    print(cname)

    model[cname] = hmm.GaussianHMM(
        n_components=states[idx],
        verbose=True,
        n_iter=300,
        startprob_prior=start_prob,
        transmat_prior=trans_matrix,
        params='stmc',
        init_params='mc',
        random_state=42
    )
    model[cname].fit(X=np.vstack(X['train'][cname]),
                            lengths=[x.shape[0] for x in X['train'][cname]])

xemnhietdo


         1     -512141.5479             +nan
         2     -478753.8568      +33387.6911
         3     -469056.5194       +9697.3374
         4     -466802.3667       +2254.1526
         5     -466134.3788        +667.9879
         6     -465780.6062        +353.7726
         7     -465406.8483        +373.7578
         8     -465114.7306        +292.1177
         9     -464982.8753        +131.8554
        10     -464920.3340         +62.5413
        11     -464840.8544         +79.4797
        12     -464738.9708        +101.8836
        13     -464630.2908        +108.6799
        14     -464533.3330         +96.9578
        15     -464425.0139        +108.3191
        16     -464391.4762         +33.5377
        17     -464368.0806         +23.3956
        18     -464359.0348          +9.0458
        19     -464353.2426          +5.7922
        20     -464348.9050          +4.3376
        21     -464344.3653          +4.5397
        22     -464341.2354          +3.1299
        23

xemgio


         1     -314934.5055             +nan
         2     -294993.4506      +19941.0549
         3     -292450.3865       +2543.0641
         4     -291734.8741        +715.5125
         5     -291253.7239        +481.1501
         6     -290892.4543        +361.2696
         7     -290524.0845        +368.3698
         8     -290322.7137        +201.3708
         9     -290277.4748         +45.2389
        10     -290247.5850         +29.8898
        11     -290217.8751         +29.7099
        12     -290169.8261         +48.0490
        13     -290135.4880         +34.3380
        14     -290103.0854         +32.4026
        15     -290071.9748         +31.1106
        16     -290046.0230         +25.9518
        17     -290034.5364         +11.4867
        18     -290026.8930          +7.6434
        19     -290021.5695          +5.3235
        20     -290018.0958          +3.4736
        21     -290015.0971          +2.9987
        22     -290011.6278          +3.4694
        23

xemngay


         1     -321352.3315             +nan
         2     -296258.3936      +25093.9378
         3     -294058.8171       +2199.5765
         4     -293897.6825        +161.1346
         5     -293813.8510         +83.8315
         6     -293740.2870         +73.5640
         7     -293713.1655         +27.1215
         8     -293703.0950         +10.0705
         9     -293696.4410          +6.6541
        10     -293691.9961          +4.4449
        11     -293689.1679          +2.8282
        12     -293687.2244          +1.9435
        13     -293685.8972          +1.3272
        14     -293684.9916          +0.9057
        15     -293684.0209          +0.9707
        16     -293680.1180          +3.9029
        17     -293674.7591          +5.3589
        18     -293673.9369          +0.8223
        19     -293673.6287          +0.3082
        20     -293673.4444          +0.1843
        21     -293673.3209          +0.1234
        22     -293673.2251          +0.0958
        23

xemthoitiet


         1     -238436.0501             +nan
         2     -216596.6727      +21839.3774
         3     -215476.0432       +1120.6294
         4     -215271.0897        +204.9535
         5     -215161.0704        +110.0193
         6     -215067.1451         +93.9253
         7     -215016.0654         +51.0796
         8     -214982.2677         +33.7978
         9     -214942.9724         +39.2952
        10     -214893.4743         +49.4982
        11     -214843.0455         +50.4288
        12     -214778.8859         +64.1596
        13     -214717.6740         +61.2120
        14     -214637.8165         +79.8574
        15     -214563.1479         +74.6687
        16     -214493.8000         +69.3479
        17     -214435.2349         +58.5651
        18     -214395.4972         +39.7377
        19     -214372.1309         +23.3663
        20     -214364.0602          +8.0707
        21     -214362.0705          +1.9897
        22     -214361.2706          +0.7999
        23

tat


         1     -268405.6998             +nan
         2     -247601.8019      +20803.8979
         3     -246299.5946       +1302.2073
         4     -246123.0209        +176.5736
         5     -246085.2429         +37.7780
         6     -246073.2423         +12.0006
         7     -246069.2326          +4.0097
         8     -246067.4170          +1.8156
         9     -246066.4508          +0.9662
        10     -246065.8914          +0.5595
        11     -246065.5130          +0.3784
        12     -246065.2057          +0.3073
        13     -246064.9285          +0.2773
        14     -246064.6550          +0.2735
        15     -246064.3912          +0.2638
        16     -246064.2455          +0.1457
        17     -246064.2409          +0.0046


bat


         1     -228104.0812             +nan
         2     -211786.4757      +16317.6055
         3     -210493.2828       +1293.1928
         4     -210311.5840        +181.6988
         5     -210231.2331         +80.3510
         6     -210154.2156         +77.0175
         7     -210056.4288         +97.7869
         8     -209900.3179        +156.1108
         9     -209728.4916        +171.8263
        10     -209610.2231        +118.2686
        11     -209515.2601         +94.9629
        12     -209421.0349         +94.2252
        13     -209351.9920         +69.0429
        14     -209308.3644         +43.6275
        15     -209288.6179         +19.7465
        16     -209279.9847          +8.6332
        17     -209274.6553          +5.3294
        18     -209270.1857          +4.4696
        19     -209260.3851          +9.8006
        20     -209250.1406         +10.2446
        21     -209233.2692         +16.8714
        22     -209211.3133         +21.9559
        23

quat


         1     -219334.7229             +nan
         2     -206069.4698      +13265.2531
         3     -204329.0837       +1740.3861
         4     -203867.4785        +461.6052
         5     -203730.9346        +136.5438
         6     -203678.7465         +52.1881
         7     -203675.5904          +3.1561
         8     -203677.1347          -1.5443


nhac


         1     -242561.1456             +nan
         2     -224418.2590      +18142.8866
         3     -224160.6399        +257.6191
         4     -224133.5444         +27.0956
         5     -224124.4076          +9.1368
         6     -224119.6653          +4.7422
         7     -224117.8173          +1.8481
         8     -224116.9986          +0.8187
         9     -224116.4754          +0.5232
        10     -224116.0582          +0.4172
        11     -224115.6949          +0.3633
        12     -224115.3628          +0.3322
        13     -224115.0529          +0.3099
        14     -224114.7711          +0.2818
        15     -224114.5336          +0.2374
        16     -224114.3554          +0.1782
        17     -224114.2377          +0.1177
        18     -224114.1681          +0.0696
        19     -224114.1298          +0.0384
        20     -224114.1092          +0.0206
        21     -224114.0980          +0.0112
        22     -224114.0918          +0.0062


den


         1     -250494.0297             +nan
         2     -233943.9056      +16550.1241
         3     -233083.9210        +859.9846
         4     -232801.4280        +282.4931
         5     -232529.3918        +272.0361
         6     -232185.3278        +344.0640
         7     -231827.3409        +357.9870
         8     -231584.1512        +243.1897
         9     -231414.1675        +169.9837
        10     -231256.9736        +157.1939
        11     -231123.1024        +133.8712
        12     -231069.0133         +54.0892
        13     -231044.1770         +24.8362
        14     -231027.4959         +16.6811
        15     -231011.8420         +15.6539
        16     -230997.3857         +14.4563
        17     -230982.5759         +14.8098
        18     -230963.4324         +19.1435
        19     -230946.9602         +16.4722
        20     -230930.8505         +16.1097
        21     -230911.1240         +19.7265
        22     -230884.5829         +26.5412
        23

In [9]:
for cname in class_names:
    name = f'{model_path}/model_{cname}.pkl'
    with open(name, 'wb') as file:
        pickle.dump(model[cname], file)

In [10]:
y_true = []
y_pred = []
for cname in class_names:
    for mfcc, target in zip(X['test'][cname], y['test'][cname]):
        scores = [model[cname].score(mfcc) for cname in class_names]
        pred = np.argmax(scores)
        y_pred.append(pred)
        y_true.append(target)
    print(f'{cname}:', (np.array(y_true) == np.array(y_pred)).sum() / len(y_true))
print('======')
print('Confusion matrix:')
print(confusion_matrix(y_true, y_pred))

xemnhietdo: 1.0
xemgio: 0.975
xemngay: 0.9833333333333333
xemthoitiet: 0.9875
tat: 0.99
bat: 0.9916666666666667
quat: 0.9928571428571429
nhac: 0.9875
den: 0.9888888888888889
Confusion matrix:
[[20  0  0  0  0  0  0  0  0]
 [ 1 19  0  0  0  0  0  0  0]
 [ 0  0 20  0  0  0  0  0  0]
 [ 0  0  0 20  0  0  0  0  0]
 [ 0  0  0  0 20  0  0  0  0]
 [ 0  0  0  0  0 20  0  0  0]
 [ 0  0  0  0  0  0 20  0  0]
 [ 0  0  0  0  0  0  0 19  1]
 [ 0  0  0  0  0  0  0  0 20]]


Train with full dataset

In [22]:
finalX = {}
for cname in class_names:
    finalX[cname] = X['test'][cname] + X['train'][cname]
    print(cname,len(finalX[cname]))

xemnhietdo 60
xemgio 60
xemngay 60
xemgio 60
tat 60
bat 60
quat 60
nhac 60
den 60


In [None]:
for idx, cname in enumerate(class_names):
    start_prob = np.full(states, 0.0)
    trans_matrix = np.full((states, states), 0.0)
    p = 0.5
    np.fill_diagonal(trans_matrix, p)
    np.fill_diagonal(trans_matrix[0:, 1:], 1 - p)
    trans_matrix[-1, -1] = 1.0

    # trans matrix
    print(cname)
    # print(trans_matrix)

    model[cname] = hmm.GaussianHMM(
        n_components=states,
        verbose=True,
        n_iter=300,
        startprob_prior=start_prob,
        transmat_prior=trans_matrix,
        params='stmc',
        init_params='mc',
        random_state=42
    )
    model[cname].fit(X=np.vstack(finalX[cname]),
        lengths=[x.shape[0] for x in finalX[cname]])

In [24]:
for cname in class_names:
    name = f'{model_path}/model_{cname}.pkl'
    with open(name, 'wb') as file:
        pickle.dump(model[cname], file)

Test with new random wav file

In [17]:
test_file_name = 'tat1.wav'
test_file_path = 'datasets/random_test_audio/'
test_mfcc = get_mfcc(test_file_path+test_file_name)
scores = [model[cname].score(test_mfcc) for cname in class_names]

print("Input:",test_file_name)
print("score:",scores)
print("Output:",class_names[np.argmax(scores)])

Input: tat1.wav
score: [-3156.6247160914318, -3266.174281367263, -4004.4029607583666, -3598.9693971457077, -3699.11411846695, -3806.9000670368246, -4542.294908987875, -4975.758346145382, -5154.21512666474]
Output: xemnhietdo


In [92]:
model['xemgio'].transmat_

array([[0.91020378, 0.        , 0.        , 0.        , 0.08500456,
        0.        , 0.00479167],
       [0.        , 1.        , 0.        , 0.        , 0.        ,
        0.        , 0.        ],
       [0.        , 0.034725  , 0.89493378, 0.        , 0.        ,
        0.        , 0.07034123],
       [0.        , 0.        , 0.        , 0.98654232, 0.        ,
        0.01345768, 0.        ],
       [0.04146354, 0.        , 0.15466062, 0.        , 0.80387584,
        0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.16248435,
        0.83751565, 0.        ],
       [0.03870414, 0.        , 0.        , 0.03763044, 0.        ,
        0.        , 0.92366542]])

In [91]:
model['xemgio'].startprob_

array([4.19839370e-117, 0.00000000e+000, 0.00000000e+000, 7.89717484e-002,
       0.00000000e+000, 9.21028252e-001, 0.00000000e+000])

In [16]:
import librosa
y, sr = librosa.load('datasets/train_audio/bat/Open1.wav')  # read .wav file
print(sr)

22050
