In [1]:
import os
import pandas as pd
import librosa
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import Conv2D, LSTM, Dense, Dropout, MaxPooling2D, Flatten, TimeDistributed
from keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder
import pickle

In [2]:
DIRECTORY = './cv-corpus-7.0-singleword/fr'

In [3]:
# Chargement du fichier tsv
data = pd.read_csv(os.path.join(DIRECTORY, 'validated.tsv'), sep='\t')
data

Unnamed: 0,client_id,path,sentence,up_votes,down_votes,age,gender,accent,locale,segment
0,02ccda0cd258d89228fc66cb072eb9bafbb1ae9f39dc8b...,common_voice_fr_23890678.mp3,huit,2,0,,,,fr,Benchmark
1,05a87054181791477a299a08fc35a6ff0c53250cae313e...,common_voice_fr_22108074.mp3,zéro,2,0,,,,fr,Benchmark
2,06c9c9e703dfa759edf4836936b42a07afd1021cedb06c...,common_voice_fr_22098482.mp3,trois,2,0,,,,fr,Benchmark
3,07a7db773acd156dd0b7fdc32f6b5eda9b32ffa1b3aee7...,common_voice_fr_21955578.mp3,quatre,6,3,,,,fr,Benchmark
4,0eb85c7dcb9b7ca2caec05a0dbbf6ee983cfab19164dac...,common_voice_fr_22157149.mp3,Firefox,3,1,,,,fr,Benchmark
...,...,...,...,...,...,...,...,...,...,...
20012,ffd847388e93bcd91855b2a4de02c87c29ed9df053da9c...,common_voice_fr_21954620.mp3,trois,3,0,,,,fr,Benchmark
20013,ffd847388e93bcd91855b2a4de02c87c29ed9df053da9c...,common_voice_fr_21954648.mp3,sept,3,0,,,,fr,Benchmark
20014,ffd847388e93bcd91855b2a4de02c87c29ed9df053da9c...,common_voice_fr_21954649.mp3,non,4,0,,,,fr,Benchmark
20015,ffd847388e93bcd91855b2a4de02c87c29ed9df053da9c...,common_voice_fr_21954650.mp3,zéro,3,0,,,,fr,Benchmark


In [4]:
# Dictionnaire pour stocker les features
feature_dict = {}

# Dossier contenant les fichiers audio
audio_dir = os.path.join(DIRECTORY, 'clips')

# Variable pour stocker la taille max
max_pad_len = 0

#Variable permettant de choisir si on utilise feature.pkl suivant True ou False
overwrite = False

# Vérifie si le fichier de features existe déjà
if os.path.exists('features.pkl') and overwrite == False:
    with open('features.pkl', 'rb') as f:
        feature_dict, max_pad_len = pickle.load(f)
else:
    # Dictionnaire pour stocker les features brutes
    raw_feature_dict = {}

    for index, row in tqdm(data.iterrows(), total=data.shape[0]):
        file_path = os.path.join(audio_dir, row['path'])
        audio, sr = librosa.load(file_path, sr=20050, duration=1)
        mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=40)
        
        # Actualiser la taille max
        if mfccs.shape[1] > max_pad_len:
            max_pad_len = mfccs.shape[1]

        raw_feature_dict[row['path']] = mfccs

    # Maintenant que nous avons max_pad_len, nous pouvons pad les features
    for path, mfccs in tqdm(raw_feature_dict.items(), total=len(raw_feature_dict)):
        pad_width = max_pad_len - mfccs.shape[1]
        mfccs = np.pad(mfccs, pad_width=((0, 0), (0, pad_width)), mode='constant')
        feature_dict[path] = mfccs.tolist() # stocke la matrice MFCC comme une liste à une dimension

    # Sauvegarde des features extraites dans un fichier pickle
    with open('features.pkl', 'wb') as f:
        pickle.dump((feature_dict, max_pad_len), f)


In [5]:
# Création du DataFrame pour le modèle
features_df = pd.DataFrame.from_dict(feature_dict, orient='index')
features_df.reset_index(drop=True, inplace=True)
features_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,30,31,32,33,34,35,36,37,38,39
0,"[-710.1168823242188, -710.1168823242188, -710....","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10421492...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10419186...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10415215...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10409770...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10402663...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10394085...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10383846...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10372328...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10358773...",...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09723954...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09677147...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09628444...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09579259...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09528191...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09475601...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09421670...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09366235...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09309474...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09251226..."
1,"[-1131.370849609375, -1122.0513916015625, -906...","[0.0, 7.610922813415527, 51.191123962402344, 5...","[0.0, -0.25896137952804565, 7.519996643066406,...","[0.0, -2.6510984897613525, 0.8749334216117859,...","[0.0, -0.09222883731126785, 1.0024974346160889...","[0.0, 2.927912712097168, 8.831277847290039, 10...","[0.0, 3.1054673194885254, 12.384306907653809, ...","[0.0, -1.2349419593811035, -8.552617073059082,...","[0.0, -2.6688475608825684, -9.47964859008789, ...","[0.0, 0.035477131605148315, -5.420425891876221...",...,"[0.0, -0.7341393232345581, -8.356950759887695,...","[0.0, -1.9580577611923218, -8.38644790649414, ...","[0.0, -1.1625733375549316, -6.293308258056641,...","[0.0, 0.9051556587219238, -6.946290016174316, ...","[0.0, 2.5427331924438477, -6.7146759033203125,...","[0.0, 4.198636054992676, 9.944777488708496, 6....","[0.0, 1.9960803985595703, -1.2293176651000977,...","[0.0, -0.6803606748580933, -3.5222692489624023...","[0.0, 0.06339943408966064, 2.782992362976074, ...","[0.0, 1.9178037643432617, 9.143674850463867, 6..."
2,"[-580.9547119140625, -502.08929443359375, -461...","[13.0746431350708, 53.5999870300293, 55.766273...","[8.943835258483887, 10.607219696044922, 6.3088...","[5.820484638214111, 1.854172945022583, 0.15633...","[5.725844383239746, 6.194729804992676, 7.47557...","[8.094457626342773, 10.655113220214844, 6.2078...","[10.422487258911133, 22.006254196166992, 11.08...","[10.410669326782227, 20.348894119262695, 15.56...","[7.969161033630371, 7.606284141540527, 4.73671...","[4.936771869659424, 10.347089767456055, 9.8350...",...,"[-0.20636723935604095, 6.80551815032959, 2.608...","[-0.4183443784713745, 0.2701655626296997, 2.27...","[-0.5794389843940735, -4.851261138916016, -4.9...","[-0.5164411067962646, -6.492099285125732, -5.3...","[-0.48947906494140625, -5.354612827301025, -6....","[-0.9195373058319092, -2.4186058044433594, -0....","[-1.9778814315795898, -2.9119930267333984, 0.0...","[-3.055476188659668, -3.901242256164551, 2.902...","[-3.088244915008545, -8.140846252441406, -4.03...","[-1.8577059507369995, -3.098597526550293, -0.9..."
3,"[-738.541259765625, -728.8106079101562, -513.5...","[0.0, 0.9618430733680725, -78.18080139160156, ...","[0.0, 6.76523494720459, 4.092826843261719, 4.8...","[0.0, 5.984655380249023, 58.10172653198242, 57...","[0.0, -5.109527111053467, -64.64938354492188, ...","[0.0, 5.047095775604248, -6.593890190124512, -...","[0.0, -8.038755416870117, 9.400949478149414, 1...","[0.0, -4.799571990966797, 18.501617431640625, ...","[0.0, -2.615769386291504, -20.82460594177246, ...","[0.0, -12.030517578125, 6.2764153480529785, 6....",...,"[0.0, 2.1856234073638916, -12.69400405883789, ...","[0.0, -0.1422356367111206, 7.330146789550781, ...","[0.0, -0.611248254776001, -1.9935173988342285,...","[0.0, 3.816716194152832, -0.4493294954299927, ...","[0.0, -1.6194462776184082, -3.6950454711914062...","[0.0, 2.600390911102295, 8.975187301635742, 4....","[0.0, 0.7847705483436584, 2.958645820617676, 0...","[0.0, -1.1721937656402588, -7.572305679321289,...","[0.0, 2.395240545272827, 1.259909749031067, 1....","[0.0, -2.0195705890655518, -3.73482608795166, ..."
4,"[-1131.370849609375, -1131.370849609375, -1131...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20012,"[-839.5950927734375, -839.394287109375, -836.6...","[0.0, 0.18814148008823395, 2.934291124343872, ...","[0.0, 0.015593774616718292, 0.9309491515159607...","[0.0, -0.03749223053455353, 0.584429144859314,...","[0.0, 0.05382423847913742, 1.3418383598327637,...","[0.0, 0.11569546163082123, 1.0098061561584473,...","[0.0, 0.00814983993768692, -0.5028185248374939...","[0.0, -0.18588508665561676, -1.314603567123413...","[0.0, -0.2622777223587036, -0.8644583225250244...","[0.0, -0.15627549588680267, -0.525296390056610...",...,"[0.0, -0.07510697096586227, 1.0791258811950684...","[0.0, 0.05139052867889404, 1.0217194557189941,...","[0.0, 0.012003745883703232, 0.4734252095222473...","[0.0, -0.09021557867527008, 0.5732320547103882...","[0.0, -0.06537541002035141, 0.9493306875228882...","[0.0, 0.10949721932411194, 0.8259821534156799,...","[0.0, 0.25439298152923584, 0.6168226003646851,...","[0.0, 0.2156600058078766, 0.9091780185699463, ...","[0.0, 0.06240655854344368, 1.003467082977295, ...","[0.0, -0.005502490326762199, 0.137090682983398..."
20013,"[-892.3887939453125, -869.9851684570312, -833....","[0.0, 15.186519622802734, 30.428752899169922, ...","[0.0, -2.473966598510742, -5.777853965759277, ...","[0.0, 2.9573512077331543, 5.990267276763916, 9...","[0.0, 9.471694946289062, 7.374838352203369, 0....","[0.0, 3.2093563079833984, 5.543201446533203, 5...","[0.0, -3.1419594287872314, -5.7005791664123535...","[0.0, 2.1100218296051025, -1.411169409751892, ...","[0.0, 6.436252593994141, 5.669757843017578, -0...","[0.0, -1.9639419317245483, -0.1664968132972717...",...,"[0.0, 2.388721466064453, 5.367902755737305, 5....","[0.0, 5.653134346008301, 6.813438415527344, 4....","[0.0, 2.83657169342041, 6.152908802032471, 3.2...","[0.0, -1.4275798797607422, 0.7191084027290344,...","[0.0, 1.2065086364746094, -1.5860974788665771,...","[0.0, 3.6196365356445312, -0.9564111232757568,...","[0.0, 1.789374828338623, 1.0633172988891602, -...","[0.0, 2.2305006980895996, 1.9038183689117432, ...","[0.0, 2.072971820831299, -0.9768301248550415, ...","[0.0, 0.27045372128486633, -1.331000566482544,..."
20014,"[-584.4408569335938, -584.4408569335938, -584....","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
20015,"[-597.8016967773438, -597.8016967773438, -597....","[0.0, 0.0, -0.5708730220794678, -110.319458007...","[0.0, 0.0, 0.4095657765865326, -6.932246208190...","[0.0, 0.0, -0.182427316904068, 67.215652465820...","[0.0, 0.0, -0.060794152319431305, -52.35047149...","[0.0, 0.0, 0.2715596854686737, 25.463989257812...","[0.0, 0.0, -0.41484248638153076, -22.958286285...","[0.0, 0.0, 0.4760551452636719, 26.057525634765...","[0.0, 0.0, -0.4609253406524658, -18.1481437683...","[0.0, 0.0, 0.38924241065979004, 8.276302337646...",...,"[0.0, 0.0, 0.06721293926239014, -6.10905027389...","[0.0, 0.0, -0.1641162782907486, -0.31138157844...","[0.0, 0.0, 0.2307223677635193, 3.2726893424987...","[0.0, 0.0, -0.2702223062515259, 0.462925195693...","[0.0, 0.0, 0.2872284948825836, -1.682683944702...","[0.0, 0.0, -0.2834075093269348, -2.79690742492...","[0.0, 0.0, 0.2561028301715851, 4.1516103744506...","[0.0, 0.0, -0.2005692422389984, 0.812142372131...","[0.0, 0.0, 0.11459233611822128, -3.07168006896...","[0.0, 0.0, -0.0031815320253372192, -1.75058865..."


In [6]:
# Encodage des labels
le = LabelEncoder()
y = to_categorical(le.fit_transform(data['sentence']))
y

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 1.],
       [0., 0., 0., ..., 1., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 1.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [7]:
# Crée un dictionnaire pour stocker les labels et leur encodage one-hot correspondant
label_encoding = {}

# Obtenir l'ordre des classes du label_encoder
classes = list(le.classes_)

# Parcourir toutes les classes
for i in range(len(classes)):
    # Crée un vecteur one-hot
    one_hot_vector = np.zeros(len(classes), dtype=int)
    one_hot_vector[i] = 1

    # Stocke le label et l'encodage one-hot correspondant dans le dictionnaire
    label_encoding[classes[i]] = one_hot_vector

# Afficher les labels et leur encodage one-hot correspondant
for label, encoding in label_encoding.items():
    print(f'Label: {label} \t One-Hot Encoding: {encoding}')


Label: Firefox 	 One-Hot Encoding: [1 0 0 0 0 0 0 0 0 0 0 0 0 0]
Label: Hey 	 One-Hot Encoding: [0 1 0 0 0 0 0 0 0 0 0 0 0 0]
Label: cinq 	 One-Hot Encoding: [0 0 1 0 0 0 0 0 0 0 0 0 0 0]
Label: deux 	 One-Hot Encoding: [0 0 0 1 0 0 0 0 0 0 0 0 0 0]
Label: huit 	 One-Hot Encoding: [0 0 0 0 1 0 0 0 0 0 0 0 0 0]
Label: neuf 	 One-Hot Encoding: [0 0 0 0 0 1 0 0 0 0 0 0 0 0]
Label: non 	 One-Hot Encoding: [0 0 0 0 0 0 1 0 0 0 0 0 0 0]
Label: oui 	 One-Hot Encoding: [0 0 0 0 0 0 0 1 0 0 0 0 0 0]
Label: quatre 	 One-Hot Encoding: [0 0 0 0 0 0 0 0 1 0 0 0 0 0]
Label: sept 	 One-Hot Encoding: [0 0 0 0 0 0 0 0 0 1 0 0 0 0]
Label: six 	 One-Hot Encoding: [0 0 0 0 0 0 0 0 0 0 1 0 0 0]
Label: trois 	 One-Hot Encoding: [0 0 0 0 0 0 0 0 0 0 0 1 0 0]
Label: un 	 One-Hot Encoding: [0 0 0 0 0 0 0 0 0 0 0 0 1 0]
Label: zéro 	 One-Hot Encoding: [0 0 0 0 0 0 0 0 0 0 0 0 0 1]


In [8]:
# Création de X
X = np.array(features_df.values.tolist())
X = X.reshape(X.shape[0], 40, -1)
X.shape

(20017, 40, 40)

In [9]:
from sklearn.preprocessing import StandardScaler

# Reshape les données en 2D pour la normalisation
X_2D = X.reshape(-1, X.shape[-1])

# Normalisation des données
scaler = StandardScaler()
X_2D = scaler.fit_transform(X_2D)

# Remettre les données en 3D
X = X_2D.reshape(X.shape)

# Reshape pour CNN 2D
X = X[..., np.newaxis]


In [10]:
# Séparation des données en train et test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


In [11]:
# Création du modèle
model = Sequential()

model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(X.shape[1], X.shape[2], X.shape[3])))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))

# Flatten output pour utiliser dans les couches LSTM
model.add(TimeDistributed(Flatten())) 

model.add(LSTM(100))
model.add(Dropout(0.5))
model.add(Dense(100, activation='relu'))
model.add(Dense(y.shape[1], activation='softmax'))

In [12]:
# Compilation du modèle
model.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer='adam')


In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Liste pour stocker les valeurs des métriques pour chaque époque
accuracy_values = []
precision_values = []
recall_values = []
f1_values = []

# Entraînement du modèle avec boucle sur les époques
for epoch in range(1, 101):
    history = model.fit(X_train, y_train, epochs=1, batch_size=32, validation_data=(X_test, y_test), verbose=0)
    y_pred = model.predict(X_test)
    
    # Calcul des métriques pour cette époque
    accuracy = accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_pred, axis=1))
    precision = precision_score(np.argmax(y_test, axis=1), np.argmax(y_pred, axis=1), average='weighted')
    recall = recall_score(np.argmax(y_test, axis=1), np.argmax(y_pred, axis=1), average='weighted')
    f1 = f1_score(np.argmax(y_test, axis=1), np.argmax(y_pred, axis=1), average='weighted')
    
    # Stockage des valeurs des métriques
    accuracy_values.append(accuracy)
    precision_values.append(precision)
    recall_values.append(recall)
    f1_values.append(f1)
    
    print(f"Epoch {epoch}: Accuracy = {accuracy:.4f}, Precision = {precision:.4f}, Recall = {recall:.4f}, F1 = {f1:.4f}")

# Création du graphe
epochs = range(1, len(accuracy_values) + 1)
plt.figure(figsize=(8, 6))
plt.plot(epochs, accuracy_values, marker='o', label='Accuracy')
plt.plot(epochs, precision_values, marker='o', label='Precision')
plt.plot(epochs, recall_values, marker='o', label='Recall')
plt.plot(epochs, f1_values, marker='o', label='F1')
plt.title('Performances du modèle')
plt.xlabel('Époque')
plt.ylabel('Score')
plt.legend()
plt.show()



  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1: Accuracy = 0.1521, Precision = 0.1068, Recall = 0.1521, F1 = 0.1080
Epoch 2: Accuracy = 0.2155, Precision = 0.2355, Recall = 0.2155, F1 = 0.1927
Epoch 3: Accuracy = 0.2857, Precision = 0.2907, Recall = 0.2857, F1 = 0.2595
Epoch 4: Accuracy = 0.3384, Precision = 0.4279, Recall = 0.3384, F1 = 0.3247
Epoch 5: Accuracy = 0.4281, Precision = 0.4724, Recall = 0.4281, F1 = 0.4382
Epoch 6: Accuracy = 0.4655, Precision = 0.5721, Recall = 0.4655, F1 = 0.4760
Epoch 7: Accuracy = 0.4713, Precision = 0.5374, Recall = 0.4713, F1 = 0.4798
Epoch 8: Accuracy = 0.5160, Precision = 0.6000, Recall = 0.5160, F1 = 0.5343
Epoch 9: Accuracy = 0.5305, Precision = 0.6329, Recall = 0.5305, F1 = 0.5524
Epoch 10: Accuracy = 0.5220, Precision = 0.5899, Recall = 0.5220, F1 = 0.5348
Epoch 11: Accuracy = 0.5207, Precision = 0.6023, Recall = 0.5207, F1 = 0.5352
Epoch 12: Accuracy = 0.5455, Precision = 0.6129, Recall = 0.5455, F1 = 0.5579
Epoch 13: Accuracy = 0.5537, Precision = 0.6291, Recall = 0.5537, F1 = 0.

In [None]:
# Compilation du modèle
#model.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer='adam')

# Entraînement du modèle
#model.fit(X_train, y_train, epochs=20, batch_size=32, validation_data=(X_test, y_test))

# Sauvegarde du modèle
#model.save('model_matthieu.h5')

In [None]:
#from sklearn.metrics import confusion_matrix
#import seaborn as sns
#import matplotlib.pyplot as plt

# Prédiction sur les données de test
#y_pred = model.predict(X_test)
#y_pred_classes = np.argmax(y_pred, axis=1)
#y_test_classes = np.argmax(y_test, axis=1)

# Calcul de la matrice de confusion
#confusion_mtx = confusion_matrix(y_test_classes, y_pred_classes)

# Affichage de la matrice de confusion
#plt.figure(figsize=(10, 8))
#sns.heatmap(confusion_mtx, annot=True, fmt='d', cmap='Blues')
#plt.xlabel('Prédiction')
#plt.ylabel('Vraie étiquette')
#plt.title('Matrice de confusion')
#plt.show()
