In [None]:
import numpy as np
import tensorflow as tf
import random
import librosa
import glob
import os
from multiprocessing import Process
from multiprocessing.pool import Pool
from multiprocessing import Manager
from multiprocessing import Process

class Model:
    """Builder for the Keras networks used in this notebook."""

    def __init__(self, out_size=10, data_path=None):
        """Store the configuration shared by the model builders.

        Args:
            out_size: Number of units in the final (sigmoid) layer.
            data_path: Optional dataset location kept on the instance. It used
                to be read from an undeclared global, which raised
                ``NameError`` whenever that global did not exist.
        """
        self.data_path = data_path
        self.out_size = out_size

    def single_lstm(self, input_shape):
        """Build and compile a single-layer LSTM classifier.

        The network is ``LSTM(256) -> Dense(64, relu) -> Dense(out_size,
        sigmoid)``, compiled with binary cross-entropy, the ``"sgd"``
        optimizer, and recall / precision / binary-accuracy metrics.

        Args:
            input_shape: Shape of one sample, without the batch dimension.

        Returns:
            The compiled ``tf.keras`` model.
        """
        model = tf.keras.models.Sequential()
        model.add(tf.keras.layers.Input(shape=input_shape))
        model.add(tf.keras.layers.LSTM(256))
        model.add(tf.keras.layers.Dense(64, activation='relu'))
        model.add(tf.keras.layers.Dense(self.out_size, activation='sigmoid'))
        model.compile(
            loss=tf.keras.losses.BinaryCrossentropy(),
            optimizer="sgd",
            metrics=[
                tf.keras.metrics.Recall(),
                tf.keras.metrics.Precision(),
                'binary_accuracy',
            ],
        )
        # summary() prints the table itself and returns None, so wrapping it
        # in print() only added a stray "None" line to the output.
        model.summary()
        return model

class DataGenerator:
    """Endless batch generators over folders of audio files.

    ``data_path`` is searched for entries matching ``CC_*``. Every sample is
    built from one randomly chosen sub-folder of one of those entries: a
    window of ``out_size`` files (sorted by name), each loaded with librosa
    and forced to exactly ``speech_length`` samples. The label of a file is
    ``True`` when its name ends with ``"0.wav"``.
    """

    def __init__(self, data_path, out_size, speech_len, validation_split = 0.1):
        """Index the dataset and split it into training and validation parts.

        Args:
            data_path: Directory containing the ``CC_*`` entries.
            out_size: Number of files (sequence steps) per sample.
            speech_len: Clip duration handed to ``librosa.time_to_samples``.
            validation_split: Fraction of the ``CC_*`` entries reserved for
                validation; the remainder is used for training.
        """
        self.data_path = data_path
        self.out_size = out_size
        self.speech_length = int(librosa.time_to_samples(speech_len))
        self.data_list = glob.glob(os.path.join(self.data_path, "CC_*"))
        random.shuffle(self.data_list)
        self.train, self.valid = self._split(self.data_list, validation_split)

    @staticmethod
    def _split(data_list, validation_split):
        """Return ``(train, valid)`` where ``valid`` holds the given fraction."""
        # The original slicing was inverted: the small ``validation_split``
        # fraction ended up as the training set and the rest as validation.
        n_valid = int(len(data_list) * validation_split)
        return data_list[n_valid:], data_list[:n_valid]

    def empty_sequence(self, n_part):
        """Return ``n_part`` all-zero clips of ``speech_length`` samples each."""
        return [[0 for _ in range(self.speech_length)] for _ in range(n_part)]

    def _batch_generator(self, source_folders, batch_size):
        """Yield ``(data, labels)`` arrays built from ``source_folders`` forever.

        One worker process per sample runs ``load_files``; the workers report
        through ``self.out``, a manager-backed list that is replaced for every
        batch. NOTE: ``self.out`` is shared by the train and validation
        generators, so they must not be advanced concurrently.
        """
        manager = Manager()
        while True:
            self.out = manager.list()
            workers = [
                Process(target=self.load_files, args=(source_folders,))
                for _ in range(batch_size)
            ]
            for worker in workers:
                worker.start()
            for worker in workers:
                worker.join()
            out_data = []
            out_labels = []
            for data, label in self.out:
                if data is not None and label is not None:
                    out_data.append(data)
                    out_labels.append(label)
            yield np.array(out_data), np.array(out_labels)

    def train_generator(self, batch_size):
        """Yield training batches of ``batch_size`` samples indefinitely."""
        yield from self._batch_generator(self.train, batch_size)

    def valid_generator(self, batch_size):
        """Yield validation batches of ``batch_size`` samples indefinitely."""
        yield from self._batch_generator(self.valid, batch_size)

    def _fit_length(self, wave):
        """Return ``wave`` left-padded with zeros or randomly cropped to ``speech_length``."""
        n_missing = self.speech_length - len(wave)
        if n_missing >= 0:
            return [0] * n_missing + list(wave)
        # randint's upper bound is exclusive, so +1 keeps the last valid
        # window reachable (range(len - size) silently excluded it).
        start_ind = np.random.randint(0, -n_missing + 1)
        return wave[start_ind:start_ind + self.speech_length]

    def load_files(self, source_folder):
        """Build one sample and append ``[audio, label]`` to ``self.out``.

        Args:
            source_folder: Sequence of ``CC_*`` paths to draw the sample from.

        When the chosen folder holds fewer than ``out_size`` files, the sample
        is front-padded with silent clips, and those padding positions are
        labelled ``True``.
        """
        # Reseed from OS entropy; presumably needed so that worker processes
        # do not all replay the parent's RNG state and pick the same files.
        np.random.seed()
        data_folder = str(np.random.choice(source_folder))
        # np.random.choice returns a numpy string, which is a Python 3 ``str``
        # and has no .decode(); convert with str() instead of decoding.
        speaker = str(np.random.choice(os.listdir(data_folder)))
        folder = os.path.join(data_folder, speaker)
        files = sorted(os.listdir(folder))
        n_missing = self.out_size - len(files)
        if n_missing >= 0:
            audio = self.empty_sequence(n_missing)
            label = [True] * n_missing
            start_choice = 0
        else:
            # Inclusive of the last full window of ``out_size`` files.
            start_choice = np.random.randint(0, -n_missing + 1)
            audio = []
            label = []
        for file in files[start_choice:start_choice + self.out_size]:
            label.append(str(file).endswith("0.wav"))
            wave, _ = librosa.load(os.path.join(folder, file))
            audio.append(self._fit_length(wave))
        self.out.append([audio, label])

class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):
    """Keras callback that appends per-epoch metrics to a text log file."""

    def on_epoch_end(self, epoch, logs=None):
        """Append the loss, recall and precision values of a finished epoch."""
        # Each row is a format template plus the ``logs`` keys that fill it.
        metric_rows = (
            ("Loss is       {:7.2f}, val_loss is      {:7.2f}.", ("loss", "val_loss")),
            ("Recall is     {:7.2f}, Presicion is     {:7.2f}.", ("recall", "precision")),
            ("Val_Recall is {:7.2f}, Val_Presicion is {:7.2f}.", ("val_recall", "val_precision")),
        )
        with open('log_single_LSTM_model.txt', 'a', encoding='utf8') as log_file:
            log_file.write("For epoch {}".format(epoch))
            log_file.write("\n")
            for template, keys in metric_rows:
                # Values are looked up row by row, right before each write.
                log_file.write(template.format(*[logs[key] for key in keys]))
                log_file.write("\n")
            log_file.write("=" * 100)

    def on_epoch_begin(self, epoch, logs=None):
        """Append a header line announcing the epoch that is about to start."""
        header = "\n" + "Starting epoch {}".format(epoch) + "\n"
        with open('log_single_LSTM_model.txt', 'a', encoding='utf8') as log_file:
            log_file.write(header)

In [None]:
if __name__ == "__main__":
    # Experiment configuration: dataset root, number of files per sample,
    # and the clip length handed to DataGenerator.
    data_path = '/home/ubuntu/ProjectVietVu/splited_data'
    out_size = 10
    speech_len = 1

    data_gen = DataGenerator(data_path, out_size, speech_len)
    train_gen = data_gen.train_generator(batch_size = 16)
    valid_gen = data_gen.valid_generator(batch_size = 8)

    # Pull one batch up front to learn the per-sample input shape.
    X, y = next(train_gen)
    print(X.shape, y.shape)

    # Pass out_size explicitly so the output layer always matches the label
    # width produced by the generator (it previously matched only because the
    # default happened to be 10 as well).
    base_model = Model(out_size = out_size)
    single_model = base_model.single_lstm((X.shape[1],X.shape[2],))

    # Keep only the weights with the lowest validation loss.
    checkpoint = tf.keras.callbacks.ModelCheckpoint("Single_LSTM_model_checkpoint.h5", monitor='val_loss', verbose=1, save_best_only=True, mode='min')
    history = single_model.fit(train_gen,
                                steps_per_epoch = 100,
                                epochs = 10,
                                verbose = 1,
                                shuffle = False,
                                validation_data = valid_gen,
                                validation_steps = 10,
                                callbacks = [checkpoint,LossAndErrorPrintingCallback()])

In [None]:
# Length of every clip in the first sample of the batch.
list(map(len, X[0]))

In [None]:
# Display the shapes of the data and label arrays from the sampled batch.
X.shape, y.shape