In [None]:
import numpy as np
import sklearn
from sklearn.model_selection import train_test_split
import tensorflow as tf
import keras
from keras.utils import Sequence
from keras.utils import to_categorical
import os
import sys


sys.path.append('..')

from preprocess.spectrogram import mel_spectrogram 
from preprocess.wav_helper import trim_audio_to_np_float


In [None]:


# Dataset root; expected layout is <DATASET_DIR>/{train,val}/<label>/*.wav
DATASET_DIR = os.path.join('..', '..', '.tstdata', 'dataset')
# Expected 2-D shape of every mel spectrogram. It must equal exactly what
# preprocess.spectrogram.mel_spectrogram returns for the trimmed clips;
# samples with any other shape are skipped by the data generator.
SPEC_SIZE = (128, 376)
SEED = 42

# Seed Python's `random`, NumPy (used by the generator's shuffling) and the
# backend (weight initialisation, dropout) so a Restart & Run All is reproducible.
keras.utils.set_random_seed(SEED)

# Class names = sub-directories of the training split. Plain files such as
# .DS_Store are ignored so they do not become bogus classes.
TRAIN_DIR = os.path.join(DATASET_DIR, 'train')
LABELS = sorted(
    name for name in os.listdir(TRAIN_DIR)
    if os.path.isdir(os.path.join(TRAIN_DIR, name))
)


In [None]:

class AudioDataGenerator(Sequence):
    """Keras ``Sequence`` yielding batches of mel spectrograms from a WAV directory tree.

    Expected layout is ``data_dir/<label>/<file>.wav``: each sub-directory of
    ``data_dir`` is one class. Each batch is ``(X, y)``:

    * ``X``: float32 array of shape ``(n, *input_shape, 1)``; each sample is
      divided by its own maximum when that maximum is positive.
    * ``y``: float32 one-hot matrix of shape ``(n, len(label_names))``.

    Files that fail to load, or whose spectrogram shape differs from
    ``input_shape``, are reported with ``print`` and dropped from their batch.
    A batch may therefore hold fewer than ``batch_size`` samples.

    Args:
        data_dir: Root directory with one sub-directory per class.
        batch_size: Maximum samples per batch. The last batch of an epoch may
            be smaller.
        input_shape: Expected 2-D spectrogram shape (any sequence; stored as a
            tuple).
        shuffle: Shuffle sample order at construction and after every epoch.
        label_names: Optional explicit class list. Pass the same list to every
            split so class indices agree even if a split lacks a class.
            Defaults to the sorted sub-directory names of ``data_dir``.
        **kwargs: Forwarded to ``Sequence.__init__``. In Keras 3 these are
            e.g. ``workers``, ``use_multiprocessing`` and ``max_queue_size``.
    """

    def __init__(self, data_dir, batch_size=32, input_shape=(128, 128), shuffle=True,
                 label_names=None, **kwargs):
        # Keras 3's PyDataset warns when subclasses skip its __init__.
        super().__init__(**kwargs)
        self.data_dir = data_dir
        self.batch_size = batch_size
        # spec.shape is a tuple; a list here would never compare equal.
        self.input_shape = tuple(input_shape)
        self.shuffle = shuffle

        if label_names is None:
            # Only sub-directories are classes. A stray file (e.g. .DS_Store)
            # would otherwise become a label and crash os.listdir below.
            label_names = sorted(
                name for name in os.listdir(data_dir)
                if os.path.isdir(os.path.join(data_dir, name))
            )
        self.label_names = list(label_names)
        self.label_to_idx = {name: i for i, name in enumerate(self.label_names)}

        # Collect file paths and their integer class indices.
        self.filepaths = []
        self.labels = []
        for label in self.label_names:
            label_dir = os.path.join(data_dir, label)
            if not os.path.isdir(label_dir):
                # Class absent from this split; only possible with explicit label_names.
                continue
            # os.listdir order is arbitrary. Sorting makes the (seeded) shuffle reproducible.
            for fname in sorted(os.listdir(label_dir)):
                if fname.lower().endswith(".wav"):
                    self.filepaths.append(os.path.join(label_dir, fname))
                    self.labels.append(self.label_to_idx[label])

        self.indices = np.arange(len(self.filepaths))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __len__(self):
        """Number of batches per epoch, including a final partial batch."""
        # Ceil, not floor. Floor division silently skipped the last
        # len % batch_size samples each epoch (always the same files for the
        # unshuffled val set), and gave 0 batches for fewer files than batch_size.
        return int(np.ceil(len(self.filepaths) / self.batch_size))

    def __getitem__(self, index):
        """Return the ``(X, y)`` batch at position ``index`` of the current order."""
        batch_idx = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        batch_files = [self.filepaths[i] for i in batch_idx]
        batch_labels = [self.labels[i] for i in batch_idx]

        X, y = self.__load_batch(batch_files, batch_labels)
        return X, y

    def on_epoch_end(self):
        """Reshuffle the sample order between epochs when ``shuffle`` is set."""
        if self.shuffle:
            np.random.shuffle(self.indices)

    def _load_spectrogram(self, path):
        """Load one WAV file and return its float32 mel spectrogram, peak-normalised.

        Raises:
            ValueError: if the spectrogram shape differs from ``input_shape``.
            Any error from reading or decoding the file propagates unchanged.
        """
        # Read and close the file before the (comparatively slow) DSP work.
        with open(path, 'rb') as f:
            audio_bytes = f.read()
        # NOTE(review): the meaning of (0, 4, 4) is defined in
        # preprocess.wav_helper.trim_audio_to_np_float; confirm there.
        spec = mel_spectrogram(trim_audio_to_np_float(audio_bytes, 0, 4, 4))
        spec = spec.astype(np.float32)

        if spec.shape != self.input_shape:
            raise ValueError(f"Invalid shape {spec.shape}, expected {self.input_shape}")

        # Per-sample normalisation so the peak becomes 1. When the peak is not
        # positive (all-zero input, or values that are all <= 0) leave it unscaled.
        peak = spec.max()
        if peak > 0:
            spec /= peak
        return spec

    def __load_batch(self, batch_files, batch_labels):
        """Build ``(X, y)`` for one batch, dropping files that could not be loaded."""
        X = np.zeros((len(batch_files), *self.input_shape, 1), dtype=np.float32)
        y = np.zeros((len(batch_files), len(self.label_names)), dtype=np.float32)
        keep = np.zeros(len(batch_files), dtype=bool)

        for i, (path, label_idx) in enumerate(zip(batch_files, batch_labels)):
            try:
                X[i, :, :, 0] = self._load_spectrogram(path)
            except Exception as e:
                # Best-effort loading: report and skip unreadable or mis-shaped files.
                print("Error processing", path, ":", e)
                continue
            y[i, label_idx] = 1.0
            keep[i] = True

        # Drop failed rows. Left in place they were all-zero inputs with
        # all-zero targets: zero loss, but still scored by the accuracy metric
        # (argmax of an all-zero target is class 0).
        # NOTE(review): if every file in a batch fails, an empty batch is returned.
        return X[keep], y[keep]


In [None]:
# Runtime setup:
# - disable XLA JIT auto-clustering;
# - let TensorFlow allocate GPU memory on demand instead of reserving it all.
# set_memory_growth raises RuntimeError once the GPUs have been initialised,
# so this cell must run before the model is built.
tf.config.optimizer.set_jit(False)

gpus = tf.config.list_physical_devices('GPU')
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)

In [None]:


# Small 2-D CNN that treats each single-channel spectrogram as an image and
# ends in a softmax over the class labels.
# NOTE(review): with SPEC_SIZE = (128, 376) the last Conv2D outputs 28x90x128,
# so Flatten feeds 322,560 features into Dense(128), about 41M weights.
# GlobalAveragePooling2D in place of Flatten would shrink that drastically.
model = keras.Sequential()
model.add(keras.layers.InputLayer(shape=(*SPEC_SIZE, 1), dtype=np.float32))
for n_filters in (32, 64):
    model.add(keras.layers.Conv2D(n_filters, (3, 3), activation='relu'))
    model.add(keras.layers.MaxPooling2D((2, 2)))
model.add(keras.layers.Conv2D(128, (3, 3), activation='relu'))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(128, activation='relu'))
model.add(keras.layers.Dropout(0.3))
model.add(keras.layers.Dense(len(LABELS), activation='softmax'))

# AudioDataGenerator yields one-hot targets, hence categorical cross-entropy.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])


In [None]:
BATCH_SIZE = 4

train_gen = AudioDataGenerator(DATASET_DIR + '/train', batch_size=BATCH_SIZE, input_shape=SPEC_SIZE)
val_gen = AudioDataGenerator(DATASET_DIR + '/val', batch_size=BATCH_SIZE, input_shape=SPEC_SIZE, shuffle=False)

# Each generator derives its class -> index mapping from its own directory
# listing. If the splits disagree with each other or with LABELS (which sizes
# the model's output layer), labels would be silently wrong or the loss would
# hit a shape mismatch. Fail fast with a readable message instead.
assert list(train_gen.label_names) == list(val_gen.label_names) == list(LABELS), (
    f"label mismatch: train={train_gen.label_names} val={val_gen.label_names} LABELS={LABELS}"
)
assert len(train_gen) > 0, "train_gen yields no batches; check the train split contains .wav files"


In [None]:

EPOCHS = 30

# No batch_size here: the generators already yield batches of BATCH_SIZE, and
# Keras documents that batch_size must not be specified for Sequence/PyDataset
# inputs. Assigning the History keeps the loss/accuracy curves available and
# suppresses the object repr in the cell output.
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=EPOCHS,
)
