In [None]:
import sys
import numpy as np
from sklearn.model_selection import train_test_split
from keras import mixed_precision
from keras.callbacks import EarlyStopping
from keras.models import Model
from keras.layers import (
    Input,
    Conv2D,
    MaxPooling2D,
    Flatten,
    Dense,
    Dropout,
    concatenate,
    BatchNormalization,
)

In [None]:
# Enable mixed precision: install "mixed_float16" as the Keras global dtype
# policy. This must run before any layers are created, because layers pick up
# the global policy at construction time.
mixed_precision.set_global_policy("mixed_float16")

In [None]:
# Make the project-local helper modules importable. The path is relative, so
# this assumes the kernel's working directory is a sibling of "helpers"
# (NOTE(review): confirm when running from a different directory).
sys.path.append("../helpers")
from ProjectData import ProjectData
from model_pipeline import DataGenerator, plot_history

# Project data accessor; queried below for the feature paths and audio labels.
prd = ProjectData()

In [None]:
# Entries registered under the "melSpectrogram_mfcc" key (presumably the
# combined mel-spectrogram + MFCC feature files -- confirm in ProjectData).
data = prd.get_data_paths()["melSpectrogram_mfcc"]

# Integer index -> label key, numbered in the order get_audio_paths() iterates.
map_labels = dict(enumerate(prd.get_audio_paths()))
num_classes = len(map_labels)

In [None]:
# Every entry of `data` is read as a mapping whose first (key, value) pair is
# (label, path); any further pairs in an entry are ignored.
label_path_pairs = [list(entry.items())[0] for entry in data]
X = [path for _, path in label_path_pairs]
y = [label for label, _ in label_path_pairs]

In [None]:
# Both splits use the same options: 20% held out, fixed seed, shuffled.
split_options = dict(test_size=0.2, random_state=42, shuffle=True)

# First hold out 20% of all samples as the test set ...
X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, **split_options)
# ... then hold out 20% of the remainder as the validation set.
X_train, X_val, y_train, y_val = train_test_split(
    X_trainval, y_trainval, **split_options
)

In [None]:
# Both model inputs share the same width (1293) and a single channel; only the
# height differs: 128 rows for the mel input, 13 rows for the MFCC input.
_shared_width, _channels = 1293, 1

batch_size = 16
input_shape_mfccs = (13, _shared_width, _channels)
input_shape_mel = (128, _shared_width, _channels)

In [None]:
def _make_generator(paths, labels):
    """Wrap one data split in a DataGenerator using the notebook-wide settings.

    Args:
        paths: Feature file paths for the split.
        labels: Label keys aligned with ``paths``.

    Returns:
        A ``DataGenerator`` configured with the shared input shapes, class
        count and batch size.
    """
    return DataGenerator(
        paths,
        labels,
        input_shape_mel,
        input_shape_mfccs,
        num_classes=num_classes,
        batch_size=batch_size,
    )


# One generator per split, all configured identically.
train_generator = _make_generator(X_train, y_train)
val_generator = _make_generator(X_val, y_val)
test_generator = _make_generator(X_test, y_test)

In [None]:
def _conv_branch(inputs, filters=(32, 64, 128)):
    """Run one input through a stack of Conv2D -> BatchNorm -> MaxPool stages.

    Args:
        inputs: Keras tensor for one feature map (the output of ``Input``).
        filters: Number of 3x3 convolution filters for each successive stage.

    Returns:
        The flattened feature tensor produced after the last pooling stage.
    """
    x = inputs
    for n_filters in filters:
        x = Conv2D(n_filters, (3, 3), activation="relu", padding="same")(x)
        x = BatchNormalization()(x)
        x = MaxPooling2D((2, 2))(x)
    return Flatten()(x)


def create_model(input_shape_mel, input_shape_mfccs, num_classes):
    """Build the two-input CNN classifier (mel-spectrogram + MFCC branches).

    Each input goes through its own convolutional branch; the flattened branch
    outputs are concatenated and fed to a dense classification head.

    Args:
        input_shape_mel: Shape of one mel input sample, without the batch axis.
        input_shape_mfccs: Shape of one MFCC input sample, without the batch
            axis.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled ``Model`` with inputs named ``"input_mel"`` and
        ``"input_mfccs"`` and a single softmax output.
    """
    # Mel spectrogram branch
    input_mel = Input(shape=input_shape_mel, name="input_mel")
    x_mel = _conv_branch(input_mel)

    # MFCC branch (built after the mel branch so layer auto-naming order is
    # the same as before the helper was extracted)
    input_mfccs = Input(shape=input_shape_mfccs, name="input_mfccs")
    x_mfccs = _conv_branch(input_mfccs)

    # Concatenate the outputs of both branches
    concatenated = concatenate([x_mel, x_mfccs])

    # Fully connected layers
    x = Dense(256, activation="relu")(concatenated)
    x = BatchNormalization()(x)
    x = Dropout(0.7)(x)
    x = Dense(128, activation="relu")(x)
    x = BatchNormalization()(x)
    x = Dropout(0.55)(x)

    # This notebook enables the global "mixed_float16" policy. Under mixed
    # precision the model output should stay float32 so the softmax
    # probabilities and the loss are not computed in half precision; with a
    # float32 global policy this override is a no-op.
    output = Dense(num_classes, activation="softmax", dtype="float32")(x)

    model = Model(inputs=[input_mel, input_mfccs], outputs=output)
    return model

In [None]:
# Build the two-branch network for the configured input shapes and label count.
model = create_model(
    input_shape_mel=input_shape_mel,
    input_shape_mfccs=input_shape_mfccs,
    num_classes=num_classes,
)

# Categorical cross-entropy loss, Adam optimizer, accuracy as the metric.
model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

In [None]:
# Print the layer-by-layer summary of the compiled model.
model.summary()

In [None]:
# Stop training once val_loss has gone 10 epochs without improving, and roll
# the model back to the weights from its best epoch.
early_stopping = EarlyStopping(
    restore_best_weights=True,
    patience=10,
    monitor="val_loss",
)

In [None]:
# Train on the training generator and validate on the validation generator
# after each epoch. 50 is an upper bound: early stopping may end training
# sooner. Adjust the number of epochs as needed.
history = model.fit(
    x=train_generator,
    epochs=50,
    callbacks=[early_stopping],
    validation_data=val_generator,
)

In [None]:
# evaluate() is unpacked into exactly two values: the loss and the single
# metric the model was compiled with.
test_loss, test_accuracy = model.evaluate(test_generator)
print(f"Test Loss: {test_loss}", f"Test Accuracy: {test_accuracy}", sep="\n")

In [None]:
# Plot the training history returned by fit() with the project helper.
# `step` is forwarded as-is (presumably the epoch tick spacing -- TODO confirm
# against model_pipeline.plot_history).
plot_history(history, step=2)