# TinyML for COVID-19 cough detection

In [None]:
!pip -q install librosa==0.8.0

In [None]:
import os
import pandas as pd
import numpy as np
from time import time
import librosa, librosa.display
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
# plt.switch_backend('agg')

import itertools
import scipy
from scipy import signal
from scipy.fftpack import dct
from scipy.stats import zscore

# from sklearn.preprocessing import LabelEncoder, scale, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
# from sklearn.decomposition import PCA

import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dense, Dropout, \
    Flatten, BatchNormalization, ReLU, DepthwiseConv2D, SeparableConv2D, AveragePooling2D
# from tensorflow.keras.layers import Convolution2D, Conv2D, MaxPooling2D, GlobalAveragePooling2D, UpSampling2D, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import optimizers
from tensorflow.keras.utils import plot_model

# Seed NumPy's and TensorFlow's random number generators with the same value.
# `rand_seed` is also passed as `random_state` to the DataFrame shuffles and
# the train/validation/test splits performed in `AudioClassifier.run`.
rand_seed = 10
np.random.seed(rand_seed)
tf.random.set_seed(rand_seed)

In [None]:
class MFCC:
    """
    Step-by-step MFCC (Mel-Frequency Cepstral Coefficients) extractor.

    The audio file is loaded once in the constructor. `mfcc()` then runs:
    framing -> Hamming window -> power spectrum -> Mel filterbanks -> DCT,
    optionally followed by sinusoidal liftering and z-score normalization.

    Pre-emphasis is NOT applied automatically: call `emphasize()` once before
    `mfcc()` when a pre-emphasis coefficient was supplied.
    """

    def __init__(self, audio_filename, sample_rate=None, pre_emphasis_alpha=None,
                 frame_size=0.025, frame_stride=0.01, num_fft=512,
                 num_mel_filters=40, num_mfcc=12,
                 liftering=False, normalization=False, cep_lifter=22):
        """
        Load the audio file and store the extraction settings.

        Args:
            audio_filename: path of the audio file to load.
            sample_rate: target sample rate in Hz; 22050 is used when None.
            pre_emphasis_alpha: coefficient used by `emphasize()`; None disables it.
            frame_size: frame length in seconds.
            frame_stride: distance between consecutive frame starts, in seconds.
            num_fft: number of FFT points.
            num_mel_filters: number of triangular Mel filters.
            num_mfcc: number of cepstral coefficients kept (coefficients 2..num_mfcc+1).
            liftering: apply sinusoidal liftering to the MFCCs.
            normalization: mean-normalize the filterbanks and z-score the MFCCs.
            cep_lifter: liftering coefficient, used only when `liftering` is True.
        """
        if sample_rate is not None:
            self.sample_rate = sample_rate
        else:
            self.sample_rate = 22050  # default sample rate of librosa

        # `sr` is passed by keyword so the call does not depend on the
        # positional order of librosa.load's parameters.
        self.audio, sr = librosa.load(audio_filename, sr=self.sample_rate)
        assert sr == self.sample_rate

        self.alpha = pre_emphasis_alpha
        self.frame_size = frame_size
        self.frame_stride = frame_stride
        self.num_fft = num_fft
        self.num_mel_filters = num_mel_filters
        self.num_mfcc = num_mfcc
        self.liftering = liftering
        self.normalization = normalization
        self.cep_lifter = cep_lifter

        # NOTE: to keep only the first 3.5 seconds of audio, slice here:
        # self.audio = self.audio[0:int(3.5 * self.sample_rate)]

    def emphasize(self):
        """
        Apply a pre-emphasis filter on the signal to amplify the high frequencies.
        A pre-emphasis filter is useful in several ways:
        (1) balance the frequency spectrum since high frequencies usually have
            smaller magnitudes compared to lower frequencies,
        (2) avoid numerical problems during the Fourier transform operation and
        (3) may also improve the Signal-to-Noise Ratio (SNR).

        The filter is the first order filter:
        y(t) = x(t) - alpha * x(t-1)

        The signal stored in `self.audio` is replaced by the filtered one, so
        calling this method twice applies the filter twice. It is a no-op when
        no coefficient was configured or when the signal is empty.
        """
        if self.alpha is None or len(self.audio) == 0:
            return
        self.audio = np.append(self.audio[0],
                               self.audio[1:] - self.alpha * self.audio[:-1])

    def get_frames(self):
        """
        Split the signal into short-time, overlapping frames.

        Frequencies in a signal change over time, so a Fourier transform over
        the whole signal would lose the frequency contours. Over a very short
        period the signal can be assumed stationary, hence the FFT is computed
        frame by frame.

        Typical frame sizes in speech processing range from 20 ms to 40 ms with
        50% (+/-10%) overlap between consecutive frames; popular settings are
        frame_size = 0.025 and frame_stride = 0.01.

        Returns:
            2D array of shape (num_frames, frame_length). The signal is
            zero-padded at the end so that no sample is dropped and at least
            one frame is always produced.

        Raises:
            ValueError: if the frame size or stride rounds to zero samples.
        """
        # Convert from seconds to samples
        frame_length = int(round(self.frame_size * self.sample_rate))
        frame_step = int(round(self.frame_stride * self.sample_rate))
        if frame_length <= 0 or frame_step <= 0:
            raise ValueError(
                "frame_size and frame_stride must span at least one sample "
                "(got {} and {} samples)".format(frame_length, frame_step))

        signal_length = len(self.audio)
        # One frame covers the first `frame_length` samples; every further
        # (possibly partial) step needs one more frame.
        remaining = max(0, signal_length - frame_length)
        num_frames = 1 + int(np.ceil(remaining / frame_step))

        # Pad the signal so that all frames have the same number of samples
        # without truncating any sample of the original signal.
        pad_signal_length = (num_frames - 1) * frame_step + frame_length
        pad_signal = np.append(self.audio,
                               np.zeros(pad_signal_length - signal_length))

        # Row i holds the sample indices of frame i.
        frame_starts = np.arange(num_frames) * frame_step
        indices = frame_starts[:, np.newaxis] + np.arange(frame_length)[np.newaxis, :]
        return pad_signal[indices]

    def mel_fileterbanks(self, pow_frames, normalization=False):
        """
        Project the power spectrum onto triangular Mel filters and convert to dB.

        Args:
            pow_frames: power spectrum, shape (num_frames, num_fft // 2 + 1).
            normalization: subtract the per-filter mean; also enabled when the
                instance was created with `normalization=True`.

        Returns:
            Array of shape (num_frames, num_mel_filters) with the filterbank
            energies in dB.
        """
        low_freq_mel = 0
        high_freq_mel = 2595 * np.log10(1 + (self.sample_rate / 2) / 700)  # Hz -> Mel
        # Equally spaced points in Mel scale
        mel_points = np.linspace(low_freq_mel, high_freq_mel, self.num_mel_filters + 2)
        hz_points = 700 * (10 ** (mel_points / 2595) - 1)  # Mel -> Hz
        bins = np.floor((self.num_fft + 1) * hz_points / self.sample_rate)

        fbank = np.zeros((self.num_mel_filters, self.num_fft // 2 + 1))
        for m in range(1, self.num_mel_filters + 1):
            f_m_minus = int(bins[m - 1])   # left
            f_m = int(bins[m])             # center
            f_m_plus = int(bins[m + 1])    # right

            # Rising and falling edges of the m-th triangle. When two bins
            # coincide the corresponding range is empty, so no division by
            # zero can happen.
            for k in range(f_m_minus, f_m):
                fbank[m - 1, k] = (k - bins[m - 1]) / (bins[m] - bins[m - 1])
            for k in range(f_m, f_m_plus):
                fbank[m - 1, k] = (bins[m + 1] - k) / (bins[m + 1] - bins[m])

        filter_banks = np.dot(pow_frames, fbank.T)
        # Numerical stability: avoid log10(0)
        filter_banks = np.where(filter_banks == 0, np.finfo(float).eps, filter_banks)
        filter_banks = 20 * np.log10(filter_banks)  # dB

        if normalization or self.normalization:
            filter_banks -= (np.mean(filter_banks, axis=0) + 1e-8)

        return filter_banks

    def mfcc(self):
        """
        Compute the MFCCs of the loaded audio signal.

        Returns:
            Array of shape (num_frames, num_mfcc).
        """
        # 1) Split the audio signal into short-time frames:
        frames = self.get_frames()

        # 2) Apply a Hamming window to each frame:
        # w[n] = 0.54 - 0.46 * cos( (2 * pi * n) / (N - 1) )
        frames = frames * np.hamming(frames.shape[1])

        # 3) Compute the power spectrum (periodogram) through an N-point FFT
        # on each frame (Short-Time Fourier-Transform):
        # P = (|FFT(x_i)|^2) / N, where x_i is the i-th frame of the signal.
        # NOTE: np.fft.rfft crops frames that are longer than num_fft.
        mag_frames = np.absolute(np.fft.rfft(frames, self.num_fft))  # Magnitude of the FFT
        pow_frames = (1.0 / self.num_fft) * (mag_frames ** 2)        # Power Spectrum

        # 4) Process the frames power spectrum using the Mel filterbanks
        mel_filtbanks_coeff = self.mel_fileterbanks(pow_frames)

        # 5) Apply Discrete Cosine Transform (DCT)
        # Filterbank coefficients are highly correlated; the DCT decorrelates
        # them and yields a compressed representation. Coefficients
        # 2..(num_mfcc + 1) are kept, the others are discarded because they
        # represent fast changes in the filterbank coefficients.
        mfcc = dct(mel_filtbanks_coeff, type=2, axis=1, norm='ortho')[:, 1:(self.num_mfcc + 1)]

        # 6) Optionally apply sinusoidal liftering to the MFCCs to de-emphasize
        # the higher coefficients.
        if self.liftering and self.cep_lifter:
            n = np.arange(mfcc.shape[1])
            lift = 1 + (self.cep_lifter / 2) * np.sin(np.pi * n / self.cep_lifter)
            mfcc = mfcc * lift

        # 7) Optionally z-score each coefficient across all frames.
        # NOTE: zscore yields NaN for a coefficient that is constant over all
        # frames (zero standard deviation).
        if self.normalization:
            mfcc = zscore(mfcc, axis=0)

        return mfcc

In [None]:
class AudioClassifier():
    def __init__(self, dataset_path="audio", labels_file='cough_dataset.csv',
                 model_name="tinyconv1", num_covid_audio=17, num_other_audio=30,
                 sample_rate=None, audio_max_samples=None, augment=True,
                 frame_length=None, hop_length=None, num_mfcc_frames=None, 
                 num_mfcc=20, num_fft=2048, num_mels=128, normalize=False, 
                 batch_size=4, lr=0.001, epochs=100, plot=False):
        self.target_names = ['not_covid', 'covid']
        self.dataset_path = dataset_path
        self.labels_file = labels_file
        self.num_covid = num_covid_audio
        self.num_notcovid = num_other_audio
        self.model_name = model_name
        
        if sample_rate is not None and sample_rate > 0:
            self.sample_rate = sample_rate
        else:
            self.sample_rate = 22050  # default sample rate of librosa: 22.05 kHz
        
        if audio_max_samples is not None and audio_max_samples < 0:
            # raise ValueError("Invalid value for audio_max_samples: {}".format(audio_max_samples))
            audio_max_samples = None
        self.audio_max_samples = audio_max_samples

        if augment not in [True, False]:
            raise ValueError("Invalid value for augment: {}".format(augment))
        self.augment = augment
        
        if frame_length is None and frame_length < 0:
            self.frame_length = num_fft   # default value for librosa
        else:
            self.frame_length = frame_length
        
        if hop_length is None and hop_length < 0:
            self.hop_length = num_fft // 4   # default value for librosa
        else:
            self.hop_length = hop_length
        
        if num_mfcc_frames is not None and num_mfcc_frames < 0:
            # raise ValueError("Invalid value for num_mfcc_frames: {}".format(num_mfcc_frames))
            num_mfcc_frames = None
        self.num_mfcc_frames = num_mfcc_frames
        
        self._check_pos_param(num_mfcc)
        self.num_mfcc = num_mfcc
        
        self._check_pos_param(num_fft)
        self.num_fft = num_fft
        
        self._check_pos_param(num_mels)
        self.num_mels = num_mels
        
        if normalize not in [True, False]:
            raise ValueError("Invalid value for normalize: {}".format(normalize))
        self.normalize = normalize

        self._check_pos_param(batch_size)
        self.batch_size = batch_size
        self._check_pos_param(lr)
        self.lr = lr
        self._check_pos_param(epochs)
        self.epochs = epochs
        if plot not in [True, False]:
            raise ValueError("Invalid value for plot: {}".format(plot))
        self.plot = plot

        """
        MFCC settings to try:
        1) https://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
           num_fft = 512, num_mels = 40, num_mfcc = 13 (2-12)
        2) http://practicalcryptography.com/miscellaneous/machine-learning/guide-mel-frequency-cepstral-coefficients-mfccs/:
           num_fft = 512, num_mels = 26, num_mfcc = 13 (2-12)
        3) Default settings of librosa:
           num_mfcc=20, num_fft=2048, num_mels=128
        """

        # if the dataset file does not exists, download it
        if self.dataset_path is None or not os.path.isdir(self.dataset_path) or \
                self.labels_file is None or not os.path.isfile(self.labels_file):
            print("Invalid dataset path. Dowloading the dataset ...")
            os.system("wget -O ./data.zip https://www.dropbox.com/sh/mjyrspykfx116lo/AADb3-7_SpUpF90LcPkbRV3Ga?dl=1")
            os.system("mkdir -p dataset && unzip data.zip -d dataset")

            self.dataset_path = os.path.join(os.getcwd(), "dataset/audio/")
            self.labels_file = os.path.join(os.getcwd(), "dataset/cough_dataset.csv")
        
        os.system("mkdir -p plots")
        os.system("mkdir -p mfcc")
    
    def _check_pos_param(self, param):
        if param is None or param < 0:
            raise ValueError("Invalid value for {}: {}".format(param.__name__, param))
    
    def _save_augmented(self, aug_audio, audio_name, aug_type):
        augmented_name = audio_name[:-4] + "_" + aug_type + ".wav"
        librosa.output.write_wav(augmented_name, aug_audio, self.sample_rate)

        return augmented_name
    
    def augment_audio(self, audio_names, audio_labels):
        augmented_audios = np.array([])
        augmented_labels = np.array([])

        for name, label in zip(audio_names, audio_labels):
            file_path = os.path.join(self.dataset_path, name)
            audio, sr = librosa.load(file_path, self.sample_rate)
            augmented_audios.append(name)
            # repeat the same label for all the augmented samples,
            # including also the label of the original sample (5 + 1 = 6 labels)
            augmented_labels = np.append(augmented_labels, [label] * 6)

            # 1) Noise Injection:
            # Add some random white noise using a noise factor of 0.005 to limit
            # the noise amplitude:
            noise = np.random.randn(len(audio))
            augmented_audio = audio + 0.005 * noise
            # Cast back to same data type
            augmented_audio = augmented_audio.astype(type(audio[0]))
            # augmented_name = name[:-4] + "_noise.wav"
            # librosa.output.write_wav(augmented_name, augmented_audio, sr)
            aug_name = self._save_augmented(augmented_audio, name, "noise")
            augmented_audios = np.append(augmented_audios, aug_name)

            # 2) Random shift:
            # Randomly shift audio to left/right with a random second. If shifting audio to left 
            # (fast forward) with x seconds, first x seconds will mark as 0 (i.e. silence).
            # If shifting audio to right (back forward) with x seconds, last x seconds will mark as 0.
            max_shift = 3  # max shift in seconds
            shift = np.random.randint(sr * shift_max)  # by default, use fast-forward shift
            # Uncommment the following line for back forward shift:
            # shift = -shift
            augmented_audio = np.roll(audio, shift)
            # Set to silence for heading/ tailing
            if shift > 0:
                augmented_audio[:shift] = 0
            else:
                augmented_audio[shift:] = 0
            aug_name = self._save_augmented(augmented_audio, name, "shift")
            augmented_audios = np.append(augmented_audios, aug_name)

            # 3) Change pitch randomly:
            # how many (fractional) steps to shift the pitch
            # A step is equal to a semitone if bins_per_octave is set to 12 (default)
            pitch_factor = 3
            augmented_audio = librosa.effects.pitch_shift(audio, sr, pitch_factor)
            aug_name = self._save_augmented(augmented_audio, name, "pitch")
            augmented_audios = np.append(augmented_audios, aug_name)

            # 4) Change speed randomly:
            # Time-stretch an audio series by a fixed rate
            # If speed_factor > 1, then the signal is sped up.
            # Otherwise, if speed_factor < 1, then the signal is slowed down.
            speed_factor = 1.5
            augmented_audio = librosa.effects.time_stretch(audio, speed_factor)
            aug_name = self._save_augmented(augmented_audio, name, "speed")
            augmented_audios = np.append(augmented_audios, aug_name)

            # 5) Normalize (min-max normalization)
            lower = np.min(np.abs(audio))
            augmented_audio = (data - lower) / (np.max(np.abs(data)) - lower)
            aug_name = self._save_augmented(augmented_audio, name, "speed")
            augmented_audios = np.append(augmented_audios, aug_name)
        
        # shuffle the list of augmented audio names and labels
        rand_idx = np.arange(len(augmented_audios))
        np.random.shuffle(rand_idx)
        augmented_audios = augmented_audios[rand_idx]
        augmented_labels = augmented_labels[rand_idx]

        return augmented_audios, augmented_labels


    def get_mfcc(self, file_name):
        try:
            """
            Load and preprocess the audio
            """
            audio, sr = librosa.load(file_name, self.sample_rate)
            assert sr == self.sample_rate
            length_s = audio.shape[0] / float(sr)

            # print("Loaded audio file: {}, num samples, length: {} s".format(file_name, len(audio), length_s))
            if self.plot:
                self.plot_audio(audio, file_name, length_s)

            # Remove vocals using foreground separation:
            # https://librosa.github.io/librosa_gallery/auto_examples/plot_vocal_separation.html
            # y_no_vocal = self.vocal_removal(audio, sample_rate)

            # Remove noise using median smoothing
            # y = self.reduce_noise_median(audio, sample_rate)

            # Only use audio above the human vocal range (85-255 Hz)
            # fmin = 260
            # fmax = 10000

            # Audio slices
            y = audio
            
            # Take the first self.audio_max_samples samples of the audio:
            y = y[:self.audio_max_samples]

            # Extract MFCC features
            """
            max_pad_length = 431
            n_mfcc = 120
            n_fft = 4096
            hop_length = 512
            n_mels = 512
            
            # mfccs = librosa.feature.mfcc(y=y, sr=sample_rate, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, fmin=fmin, fmax=fmax)
            mfccs = librosa.feature.mfcc(y=y, sr=sample_rate, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
            pad_width = max_pad_length-mfccs.shape[1]
            mfccs = np.pad(mfccs, pad_width=((0,0),(0,pad_width)), mode='constant')
            # print(mfccs.shape)
            # mfccs_scaled = np.mean(mfccs.T, axis=0)
            """
            mfccs = librosa.feature.mfcc(y=y, sr=self.sample_rate,
                                         n_mfcc=self.num_mfcc, 
                                         n_fft=self.num_fft,
                                         n_mels=self.num_mels,
                                         win_length=self.frame_length,
                                         hop_length=self.hop_length
                                         )
            if self.plot:
                self.plot_mfcc(file_name, mfccs, sr, time=length_s)

        except Exception as e:
            print("Error encountered while processing file: {}".format(e))
            return None, None, None 

        return mfccs, self.sample_rate, length_s
    
    def plot_audio(self, audio, name, time_length):
        """
        Show the waveform of an audio signal.

        Args:
            audio: sequence of audio samples, plotted on the y axis.
            name: label of the audio, shown in the plot title.
            time_length: value mapped to the last sample on the x axis.
        """
        plt.figure(figsize=(10, 4))
        plt.title("Audio sample {}".format(name))
        plt.ylabel("Amplitude")
        # x axis: len(audio) evenly spaced points from 0 to time_length
        time_axis = np.linspace(0, time_length, len(audio))
        plt.plot(time_axis, audio)
        plt.show()
    
    def load_data(self, data_names, data_labels):
        mfcc_data = None
        labels = []

        for name, label in zip(data_names, data_labels):
            file_path = os.path.join(self.dataset_path, name)
            mfcc, sample_rate, time_len = self.get_mfcc(file_path)
            if mfcc is not None:
                label_str = "covid" if label == 1. else "not_covid"
                print("Loaded audio: {}, Label: {}".format(name, label_str))
                print("MFCC initial shape: {}".format(mfcc.shape))
                n_mfcc, n_frames = mfcc.shape

                if self.num_mfcc_frames is not None:
                    if n_frames >= self.num_mfcc_frames:
                        # take only the needed MFCC frames
                        mfcc = mfcc[:, :self.num_mfcc_frames]
                    else:
                        # padd with zero
                        pad_width = self.num_mfcc_frames - n_frames
                        mfcc = np.pad(mfcc, pad_width=((0, 0), (0, pad_width)), mode='constant')

                if self.normalize:
                    # mfcc = mfcc.T  # transpose mfcc -> shape: (n_frames, n_mfcc)
                    # Normalize data:
                    # NOTE: mfcc array is a 2D matrix of shape: (n_mfcc, n_frames).
                    # Rows represent features (MFCC), while columns represent samples (frames).
                    # We need to apply z-score normalization feature-wise, so considering all
                    # the frames related to a specific MFCC feature.
                    # So, we need to normalize along axis=1 (horizontally across columns / frames). 
                    # mfcc = zscore(mfcc, axis=1)
                    mean = np.mean(mfcc, axis=1).reshape(-1, 1) # + 1e-8
                    std_dev = np.std(mfcc, axis=1)
                    std_dev = np.array([1. if std_i == 0. else std_i for std_i in std_dev]).reshape(-1, 1)
                    mfcc = (mfcc - mean) / std_dev
                    # mfcc = mfcc.T
                
                if self.plot:
                    self.plot_mfcc(name + " Processed", mfcc, sample_rate, time=time_len)

                mfcc = mfcc.reshape(1, mfcc.shape[0], mfcc.shape[1], 1)

                print("MFCC new shape: {}".format(mfcc.shape))

                if mfcc_data is None:
                    mfcc_data = mfcc
                else:
                    mfcc_data = np.append(mfcc_data, mfcc, axis=0)
                # labels.append(1. if label is 'covid' else 0.)
                labels.append(label)

            else:
                print("Cannot load {}".format(file_path))
        
        # stop = input("Stop")
        return mfcc_data, np.array(labels)

    def run(self):
        """
        Run the whole pipeline: read the labels file, select and (optionally)
        augment the samples, split them into train / validation / test sets,
        extract the MFCC features, then build, train and evaluate the model.

        Raises:
            ValueError: if `self.model_name` is not a known model.
            RuntimeError: if no audio could be loaded for one of the sets.
        """
        # Fail fast on an unknown model, before any expensive processing
        model_builders = {
            "tinyconv1": self.build_tinyconv1_model,
            "tinyconv2": self.build_tinyconv2_model,
            "tinyconv3": self.build_tinyconv3_model,
            "tinyconv4": self.build_tinyconv4_model,
        }
        if self.model_name not in model_builders:
            raise ValueError("Unknown model name: {} (expected one of {})".format(
                self.model_name, sorted(model_builders)))

        colnames = ['name', 'label']
        audio_names = []
        audio_labels = []

        audio_samples = pd.read_csv(self.labels_file, names=colnames, header=None)

        covid_samples = audio_samples.loc[audio_samples['label'] == "covid"]
        other_samples = audio_samples.loc[audio_samples['label'] == "not_covid"]

        print("covid_samples: {}".format(covid_samples))
        print("other_samples: {}".format(other_samples))

        # shuffle df and take the needed samples
        covid_samples = covid_samples.sample(frac=1, random_state=rand_seed).reset_index(drop=True)[:self.num_covid]
        other_samples = other_samples.sample(frac=1, random_state=rand_seed).reset_index(drop=True)[:self.num_notcovid]

        print("covid_samples shuffled: {}".format(covid_samples))
        print("other_samples shuffled: {}".format(other_samples))

        audio_names.extend(covid_samples.name.tolist())
        audio_names.extend(other_samples.name.tolist())

        audio_labels.extend(covid_samples.label.tolist())
        audio_labels.extend(other_samples.label.tolist())

        print("audio_names: {}".format(audio_names))
        print("audio_labels: {}".format(audio_labels))

        audio_labels = [1. if l == "covid" else 0. for l in audio_labels]

        self.num_classes = len(np.unique(audio_labels))

        # AUGMENTATION NOTES:
        # 1st Method:
        # By default we use 30 non-covid samples and 17 covid samples, so in total 47 samples.
        # To solve the issue of the low number of samples, we adopt the following approach:
        # 1) Split the initial dataset into train (60 % => 28 samples) and test (40 % => 19)
        # 2) Further split the second set (test) into validation (40 % of 19 => 8)
        #    and test (60 % of 19 => 11) sets
        # 3) Apply augmentation techniques to train samples to increase its size (x4)
        #
        # 2nd Method:
        # Apply data augmentation to the original dataset to increase its size and then
        # split the augmented dataset into training, validation and test.
        # The main disadvantage is that usually validation and test sets should not be augmented,
        # but in this case it should be fine since the initial size of the dataset is too low to
        # obtain good performance and above all to avoid overfitting.
        # In this approach we adopt a train-validation-test split equal to 70-15-15

        # Use 2nd method for data augmentation:
        if self.augment:
            audio_names, audio_labels = self.augment_audio(audio_names, audio_labels)

        train_names, test_names, train_labels, test_labels = train_test_split(
            audio_names, audio_labels, test_size=0.3, random_state=rand_seed, stratify=audio_labels)

        test_names, val_names, test_labels, val_labels = train_test_split(
            test_names, test_labels, test_size=0.5, random_state=rand_seed, stratify=test_labels)

        print("train_names: {}, num_train: {}".format(train_names, len(train_names)))
        print("train_labels: {}".format(train_labels))
        assert len(train_names) == len(train_labels)
        print("val_names: {}, num_val: {}".format(val_names, len(val_names)))
        print("val_labels: {}".format(val_labels))
        assert len(val_names) == len(val_labels)
        print("test_names: {}, num_test: {}".format(test_names, len(test_names)))
        print("test_labels: {}".format(test_labels))
        assert len(test_names) == len(test_labels)

        X_train, y_train = self.load_data(train_names, train_labels)
        X_test, y_test = self.load_data(test_names, test_labels)
        X_val, y_val = self.load_data(val_names, val_labels)

        # load_data returns None when none of the files could be processed
        for set_name, features in (("training", X_train), ("test", X_test), ("validation", X_val)):
            if features is None:
                raise RuntimeError("No audio could be loaded for the {} set".format(set_name))

        print("X_train shape: {}".format(X_train.shape))
        print("y_train: {}, shape: {}".format(y_train, y_train.shape))
        print("X_val shape: {}".format(X_val.shape))
        print("y_val: {}, shape: {}".format(y_val, y_val.shape))
        print("X_test shape: {}".format(X_test.shape))
        print("y_test: {}, shape: {}".format(y_test, y_test.shape))

        input_shape = X_train[0].shape
        model = model_builders[self.model_name](input_shape)

        opt = optimizers.Adam(learning_rate=self.lr)
        model.compile(loss='sparse_categorical_crossentropy', metrics=['accuracy'], optimizer=opt)
        model.summary()
        plot_model(model, to_file='plots/model.png', show_shapes=True, show_layer_names=True)

        # Train the model, keeping on disk the weights with the best
        # validation accuracy
        checkpoint = ModelCheckpoint(filepath='saved_models/keras_model.h5',
                                     monitor='val_accuracy',
                                     verbose=1,
                                     save_best_only=True)

        start = time()
        history = model.fit(X_train, y_train, batch_size=self.batch_size,
                            epochs=self.epochs,
                            validation_data=(X_val, y_val),
                            shuffle=False,
                            callbacks=[checkpoint],
                            verbose=2)
        duration = time() - start
        print("Training completed in {:.2f} s".format(duration))

        # Evaluate the model (as it is after the last epoch) on the training
        # and on the test set
        score = model.evaluate(X_train, y_train, verbose=0)
        print("Training accuracy: {:.3f}".format(score[1] * 100))

        score = model.evaluate(X_test, y_test, verbose=0)
        print("Testing accuracy: {:.3f}".format(score[1] * 100))

        # Plots and reports
        self.plot_history(history)

        y_pred = np.argmax(model.predict(X_test), axis=1)
        print("Target labels (test set): {}".format(y_test))
        print("Predicted labels (test set): {}".format(y_pred))
        # not named `cm`, which is the alias of matplotlib.cm in this file
        conf_matrix = confusion_matrix(y_test, y_pred)
        self.plot_confusion_matrix(conf_matrix, self.target_names)
        self.plot_classification_report(y_test, y_pred)

    def build_tinyconv1_model(self, input_shape):
        """
        Build the "tinyconv1" network: a single strided convolution followed by
        dropout and a softmax classifier with `self.num_classes` outputs.

        Args:
            input_shape: shape of one input sample.

        Returns:
            The (uncompiled) Sequential model.
        """
        stack = (
            Input(shape=input_shape, name="input"),
            Conv2D(filters=8, kernel_size=(10, 8), strides=2, padding="same",
                   activation='relu', name="conv"),
            Dropout(0.5, name="dropout"),
            Flatten(name="flatten"),
            Dense(self.num_classes, activation='softmax', name="dense"),
        )

        net = Sequential()
        for layer in stack:
            net.add(layer)
        return net
    
    def build_tinyconv2_model(self, input_shape):
        """
        Build the "tinyconv2" network: two convolution -> dropout -> max-pooling
        stages followed by a softmax classifier with `self.num_classes` outputs.

        Args:
            input_shape: shape of one input sample.

        Returns:
            The (uncompiled) Sequential model.
        """
        stack = (
            Input(shape=input_shape, name="input"),
            # first stage: 32 filters
            Conv2D(filters=32, kernel_size=(20, 8), strides=1, padding="same",
                   activation='relu', name="conv1"),
            Dropout(0.25, name="drop1"),
            MaxPooling2D(pool_size=(2, 2), name="pool1"),
            # second stage: 64 filters
            Conv2D(filters=64, kernel_size=(10, 4), strides=1, padding="same",
                   activation='relu', name="conv2"),
            Dropout(0.5, name="drop2"),
            MaxPooling2D(pool_size=(2, 2), name="pool2"),
            # classifier
            Flatten(name="flatten"),
            Dense(self.num_classes, activation='softmax', name="dense"),
        )

        net = Sequential()
        for layer in stack:
            net.add(layer)
        return net
    
    def build_tinyconv3_model(self, input_shape):
        """
        Build the "tinyconv3" network: two 3x3 convolutions, max-pooling and
        dropout, then a 128-unit dense layer and a softmax classifier with
        `self.num_classes` outputs.

        Args:
            input_shape: shape of one input sample.

        Returns:
            The (uncompiled) Sequential model.
        """
        stack = (
            Input(shape=input_shape, name="input"),
            # feature extractor
            Conv2D(filters=32, kernel_size=(3, 3),
                   activation='relu', name="conv1"),
            Conv2D(filters=64, kernel_size=(3, 3),
                   activation='relu', name="conv2"),
            MaxPooling2D(pool_size=(2, 2), name="pool1"),
            Dropout(0.25, name="drop1"),
            # classifier
            Flatten(name="flatten"),
            Dense(units=128, activation='relu', name="dense1"),
            Dropout(0.5, name="drop2"),
            Dense(self.num_classes, activation='softmax', name="dense2"),
        )

        net = Sequential()
        for layer in stack:
            net.add(layer)
        return net
    
    def build_tinyconv4_model(self, input_shape):
        """Build the "tinyconv4" network with the Keras functional API.

        Structure: a strided 3x3 convolution (32 filters), a strided 3x3
        depthwise convolution and a 1x1 pointwise convolution (64 filters).
        Each of the three is followed by batch normalisation and a ReLU capped
        at 6. The result goes through 3x3 average pooling, a flatten step and
        a softmax layer with ``self.num_classes`` outputs.

        Args:
            input_shape: shape handed to the Keras ``Input`` layer.

        Returns:
            The assembled functional ``Model``; it is not compiled here.
        """
        # Named `inputs` so the builtin `input` is not shadowed.
        inputs = Input(shape=input_shape)

        conv1 = Conv2D(filters=32, kernel_size=(3, 3), strides=2, padding="same", name="conv1")(inputs)
        bn1 = BatchNormalization(name="bn1")(conv1)
        relu1 = ReLU(max_value=6, name="relu1")(bn1)

        dw2 = DepthwiseConv2D(kernel_size=(3, 3), strides=2, padding="same", name="dw2")(relu1)
        bn2 = BatchNormalization(name="bn2")(dw2)
        relu2 = ReLU(max_value=6, name="relu2")(bn2)

        pw3 = Conv2D(filters=64, kernel_size=(1, 1), strides=1, padding="same", name="pw3")(relu2)
        bn3 = BatchNormalization(name="bn3")(pw3)
        relu3 = ReLU(max_value=6, name="relu3")(bn3)

        pool4 = AveragePooling2D(pool_size=(3, 3))(relu3)
        flatten = Flatten(name="flatten")(pool4)
        dense5 = Dense(self.num_classes, activation='softmax', name="dense5")(flatten)

        # The model is created once, from the finished graph; no placeholder
        # Model() instance is needed beforehand.
        return Model(inputs=inputs, outputs=dense5)

    def plot_history(self, history):
        """Save the training/validation accuracy and loss curves as PNG files.

        Args:
            history: object whose ``history`` attribute maps the keys
                'accuracy', 'val_accuracy', 'loss' and 'val_loss' to
                per-epoch values.

        The accuracy curves go to ``plots/accuracy.png`` and the loss curves
        to ``plots/loss.png``.
        """
        curves = history.history

        # (train key, validation key, title, y label, output file, clean-up).
        # The figure is cleared after the first chart and closed after the last.
        charts = (
            ('accuracy', 'val_accuracy', 'Model accuracy', 'Accuracy',
             'plots/accuracy.png', plt.clf),
            ('loss', 'val_loss', 'Model loss', 'Loss',
             'plots/loss.png', plt.close),
        )

        for train_key, val_key, heading, y_label, target, finish in charts:
            plt.plot(curves[train_key])
            plt.plot(curves[val_key])
            plt.title(heading)
            plt.ylabel(y_label)
            plt.xlabel('Epoch')
            plt.legend(['Train', 'Validation'], loc='upper left')
            plt.savefig(target)
            finish()

    def plot_classification_report(self, x_test, y_test):
        """Print the classification report and store it as a CSV file.

        Despite their names, ``x_test`` is forwarded to scikit-learn as
        ``y_true`` and ``y_test`` as ``y_pred``.

        Args:
            x_test: ground-truth labels.
            y_test: predicted labels.

        The tabular report is written to ``plots/classification_report.csv``.
        """
        shared_args = dict(y_true=x_test, y_pred=y_test,
                           target_names=self.target_names)

        # Human-readable text version for the console
        print(classification_report(**shared_args))

        # Dictionary version, transposed before being written to disk
        metrics = classification_report(output_dict=True, **shared_args)
        report_table = pd.DataFrame(metrics).transpose()
        report_table.to_csv('plots/classification_report.csv', index=True)

    def plot_confusion_matrix(self, cm, target_names, title='Confusion matrix', cmap=None, normalize=True):
        """Draw a confusion matrix and save it to ``plots/confusion_matrix.png``.

        Args:
            cm: confusion matrix of raw counts; rows are labelled as true
                classes and columns as predicted classes.
            target_names: class names used as tick labels, or None for no
                tick labels.
            title: figure title.
            cmap: matplotlib colormap; defaults to 'Blues' when None.
            normalize: when True, each row is divided by its sum and cells
                show the resulting fractions; otherwise raw counts are shown.

        Note:
            This updates the global matplotlib ``font.size`` setting to 22.
        """
        matplotlib.rcParams.update({'font.size': 22})
        cm = np.asarray(cm)

        # Accuracy is always derived from the raw counts.
        accuracy = np.trace(cm) / float(np.sum(cm))
        misclass = 1 - accuracy

        if cmap is None:
            cmap = plt.get_cmap('Blues')

        if normalize:
            # Normalise BEFORE drawing so that cell colours, printed values
            # and the text-contrast threshold all refer to the same numbers.
            # Rows without any sample stay at 0 instead of turning into NaN.
            row_totals = cm.sum(axis=1, keepdims=True)
            cm = np.divide(cm.astype('float'), row_totals,
                           out=np.zeros(cm.shape, dtype=float),
                           where=row_totals != 0)
        cell_format = "{:0.4f}" if normalize else "{:,}"

        plt.figure(figsize=(14, 12))
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()

        if target_names is not None:
            tick_marks = np.arange(len(target_names))
            plt.xticks(tick_marks, target_names, rotation=45)
            plt.yticks(tick_marks, target_names)

        # Cells above the threshold get white text, the others black.
        thresh = cm.max() / 1.5 if normalize else cm.max() / 2
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, cell_format.format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

        plt.tight_layout()
        plt.ylabel('True label')
        plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
        plt.savefig('plots/confusion_matrix.png', bbox_inches = "tight")
        plt.close()

    def plot_mfcc(self, filename, mfcc, sr, time=None):
        """Display an MFCC array as a seaborn heat map.

        Args:
            filename: text used as the figure title.
            mfcc: array drawn by ``sns.heatmap``; the axes are labelled as
                coefficients (y) by frames (x).
            sr: not used by this method.
            time: not used by this method.
        """
        label_style = {'fontsize': 12}

        plt.figure(figsize=(10, 6))
        sns.heatmap(mfcc)

        plt.title(filename)
        plt.xlabel('# Frames', **label_style)
        plt.ylabel('MFCC Coefficients', **label_style)

        plt.show()

    def reduce_noise_median(self, y, sr):
        """
            NOISE REDUCTION USING MEDIAN:
            applies a median filter with a window of 3 samples to the signal
            and returns the filtered signal (same length as the input).
            https://github.com/dodiku/noise_reduction/blob/master/noise.py

            y:  audio signal to filter
            sr: sample rate; not used by the filter
        """
        # Imported locally so the method does not rely on an `sp` alias,
        # which the file's import section does not define.
        from scipy.signal import medfilt

        return medfilt(y, 3)

    def vocal_removal(self, y, sr):
        """
        Estimate the background (non-vocal) part of a signal and return the
        first 10 seconds of it as audio.

        The magnitude spectrogram is compared against a nearest-neighbour
        median filter of itself; a soft mask built from that comparison keeps
        the background component.

        y:  audio signal
        sr: sample rate, used to convert seconds into frame counts

        https://librosa.github.io/librosa_gallery/auto_examples/plot_vocal_separation.html
        """
        # Frame range covering seconds 0 to 10
        idx = slice(*librosa.time_to_frames([0, 10], sr=sr))
        S_full, phase = librosa.magphase(librosa.stft(y))

        # Frames are aggregated by median over cosine-similar neighbours;
        # `width` is the frame count corresponding to 2 seconds.
        S_filter = librosa.decompose.nn_filter(S_full,
                                       aggregate=np.median,
                                       metric='cosine',
                                       width=int(librosa.time_to_frames(2, sr=sr)))
        # The filter output must never exceed the input magnitude
        S_filter = np.minimum(S_full, S_filter)

        margin_i = 2
        power = 2
        mask_i = librosa.util.softmask(S_filter,
                                       margin_i * (S_full - S_filter),
                                       power=power)
        S_background = mask_i * S_full

        # Convert back to audio. The original phase is re-applied so that the
        # inverse STFT works on a complex spectrogram rather than on bare
        # magnitudes.
        audio_minus_vocals = librosa.core.istft((S_background * phase)[:, idx])

        return audio_minus_vocals

In [None]:
# Configure the classifier: "tinyconv1" model name, 16 kHz sample rate,
# 12 MFCCs over 32 frames, batch size 4, lr 0.001, 50 epochs, plotting off.
# NOTE(review): model_name presumably selects the matching build_*_model
# method -- confirm in the AudioClassifier constructor.
audioClassifier = AudioClassifier(dataset_path="dataset/audio/",
                                  labels_file="dataset/cough_dataset.csv",
                                  model_name="tinyconv1",
                                  sample_rate=16000, frame_length=1024, hop_length=512,
                                  num_mfcc_frames=32, num_mfcc=12, num_fft=1024, num_mels=40, 
                                  normalize=False, batch_size=4, lr=0.001, epochs=50, plot=False)

In [None]:
# Execute the classifier's run() entry point; its printed output follows.
audioClassifier.run()

covid_samples:                                               name  label
0   cough-shallow-3CwioNQVDBQ6CttLyFVRJpMpVHk2.wav  covid
1                      pos-0421-084-cough-m-50.wav  covid
2     cough-heavy-6T43bddKoKfG7MwnJWvrPZSsyrc2.wav  covid
3     cough-heavy-hNAGUEhL2Nh7V89at3yFEjQYo6c2.wav  covid
4     cough-heavy-hte8VptUoGVFEqvHpbh5brgfcNP2.wav  covid
5   cough-shallow-QjBZv868nydJzk0ZzwgKDHSG6Q82.wav  covid
6                      pos-0421-092-cough-m-53.wav  covid
7                      pos-0422-096-cough-m-31.wav  covid
8                      pos-0421-094-cough-m-51.wav  covid
9   cough-shallow-hNAGUEhL2Nh7V89at3yFEjQYo6c2.wav  covid
10                     pos-0421-086-cough-m-65.wav  covid
11                     pos-0421-087-cough-f-40.wav  covid
12    cough-heavy-3CwioNQVDBQ6CttLyFVRJpMpVHk2.wav  covid
13  cough-shallow-6T43bddKoKfG7MwnJWvrPZSsyrc2.wav  covid
14    cough-heavy-QjBZv868nydJzk0ZzwgKDHSG6Q82.wav  covid
15  cough-shallow-hte8VptUoGVFEqvHpbh5brgfcNP2.wav  covid