# In [None]:
import json
import os.path
import pickle

import librosa
import numpy as np
import pandas as pd
from bidict import bidict
from keras import Sequential
from keras.src.callbacks import ModelCheckpoint
from keras.src.utils import to_categorical
from matplotlib import pyplot as plt
from python_speech_features import mfcc
from sklearn.metrics import accuracy_score
from tqdm import tqdm

from models.instrument_first_models import CNN_One, RNN_One, CNN_Two
from utils.config import IRMAS_MFCC_Config

from utils.instrument_data import InstrumentDataManipulator, DataPreprocessor, IRMASDataManipulator
from scipy.io import wavfile

from sklearn.utils import compute_class_weight
from keras.models import load_model


#### Load General Config
# Read the shared project configuration. The encoding is passed explicitly:
# JSON text is UTF-8 by specification, whereas open() without an encoding
# falls back to the platform locale and could mis-decode non-ASCII characters
# in the configured paths.
with open("../config.json", encoding="utf-8") as json_data_file:
    general_config = json.load(json_data_file)

# Values read from the config keys below. Judging by the key names these are
# the IRMAS training-set location and a base directory for pickle files
# (NOTE(review): confirm against config.json).
dataset_path = general_config["IRMAS_Training_Path_Alt"]
external_pickle_path = general_config["External_Pickle_Path_Alt"]

# Two-way mapping between an integer class index and a three-letter label code
# (bidict allows lookup in both directions, index -> code and code -> index).
# NOTE(review): presumably the indices are meant to line up with the model's
# output units; the codes are listed in alphabetical order -- confirm against
# the label order used for training.
CLASS_DICTIONARY = bidict({
    0: "cel",
    1: "cla",
    2: "flu",
    3: "gac",
    4: "gel",
    5: "org",
    6: "pia",
    7: "sax",
    8: "tru",
    9: "vio",
    10: "voi"
})
#### Load DataFrame information and Class Distribution
# Per-file statistics table, indexed by the 'filename' column.
df = pd.read_csv('../statistics/irmas.csv').set_index('filename')

# Distinct labels (np.unique returns them sorted) and the summed 'length'
# of all files belonging to each label.
classes = list(np.unique(df['label']))
class_dist = df.groupby(['label'])['length'].sum()

# One tab20 colour per class, sampled at evenly spaced colormap positions.
n_classes = len(class_dist)
colors = plt.cm.tab20(np.arange(n_classes) / n_classes)

# Pie chart of each class's share of the total length.
fig, ax = plt.subplots()
ax.set_title('Class Distribution', y=1.08)
ax.pie(
    class_dist,
    labels=class_dist.index,
    autopct='%1.1f%%',
    shadow=False,
    startangle=90,
    colors=colors,
)
ax.axis('equal')
plt.show()

# Number of 0.1-long windows that fit into the total length
# (a tenth of a second, assuming 'length' is stored in seconds -- TODO confirm).
total_length = df['length'].sum()
n_samples = int(total_length / 0.1)

# Fraction of the total length contributed by each class, i.e. the
# probability distribution of each instrument being picked.
prob_dist = class_dist / class_dist.sum()

print(df.index)
#### Fitting and saving the models
# The pickle path is assembled with os.path.join instead of hard-coded
# backslashes, so the same script also yields a valid path on non-Windows hosts.
config = IRMAS_MFCC_Config(
    model_path="../instrument_models/non_rand/IRMASconv_trunc_1sec.keras",
    pickle_path=os.path.join(external_pickle_path, "non_rand", "IRMASconv_trunc_1sec.p"),
    step=1,
)
data_manipulator = IRMASDataManipulator(config)

# Rebind the working names to None before rebuilding the data.
# NOTE(review): presumably this drops references left over from an earlier
# notebook run so large arrays can be freed first -- confirm intent.
model_wrapper = None
y_flat = None
X, y = None, None

X, y = data_manipulator.build_trucated_data(dataset_path=dataset_path, df=df, classes=classes, input_type="librosa")
print("X shape : ", X.shape)
print("y shape : ", y.shape)  # one-hot encoded class matrix

# Integer class index per sample, recovered from the one-hot rows of y.
y_flat = np.argmax(y, axis=1)

# Add a trailing channel axis of size 1 to the per-sample feature dimensions.
input_shape = (X.shape[1], X.shape[2], 1)
print("CNN Input Shape : ", input_shape)

model_wrapper = CNN_One(input_shape=input_shape, output_shape=len(classes))
model_wrapper.create_model()

# Keep only the epoch with the highest monitored validation metric on disk.
# NOTE(review): the logged name is 'val_acc' only if the model was compiled
# with the metric named 'acc'; with 'accuracy' it is 'val_accuracy'. The
# compile call is not visible here -- confirm in CNN_One.create_model().
monitored_metric = 'val_acc'
checkpoint = ModelCheckpoint(filepath=config.model_path, monitor=monitored_metric, verbose=1, mode="max",
                             save_best_only=True, save_weights_only=False)

# Assumes model_wrapper.model is a Keras model whose fit() returns a History object.
history = model_wrapper.model.fit(X, y, epochs=10, batch_size=32, shuffle=True, validation_split=0.1,
                                  callbacks=[checkpoint])

# Do not unconditionally re-save to config.model_path: that would replace the
# best-epoch checkpoint written above with the last-epoch weights. Save the
# final model only as a fallback, when the monitored metric was never logged
# and the checkpoint callback therefore could not write anything.
if monitored_metric not in history.history:
    model_wrapper.model.save(config.model_path)