In [3]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization

from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split

# Load the complex STFT tensors and their class labels from disk.
X_stft = np.load('X_stft_181.npy')
y_labels = np.load('bd1.npy')

# Remap labels to contiguous class indices 0..K-1.
# Subtracting the minimum alone is only correct when the raw label values are
# consecutive; with gaps (e.g. 1, 2, 5) the largest label would exceed
# output_classes - 1 and be out of range for the softmax output used with
# sparse categorical cross-entropy. return_inverse yields, for every label,
# its index into the sorted unique values, which is always a valid class id.
_, y_inverse = np.unique(y_labels, return_inverse=True)
# Reshape explicitly: the shape of the inverse for non-1-D input differs
# between NumPy versions.
y_labels = y_inverse.reshape(y_labels.shape)

# Class ids after remapping and the resulting number of output units.
unique_labels = np.unique(y_labels)
output_classes = len(unique_labels)

# Represent each complex STFT value as two real channels (real, imaginary)
# stacked on a new trailing axis, so a real-valued network can consume it.
X_stft_real = X_stft.real
X_stft_imag = X_stft.imag
X_stft_combined = np.stack((X_stft_real, X_stft_imag), axis=-1)  # expected shape: (850, 257, 25, 2) -- TODO confirm against the .npy file

# Hold out 20% of the samples for validation; the fixed seed keeps the split
# reproducible between runs.
X_train, X_val, y_train, y_val = train_test_split(X_stft_combined, y_labels, test_size=0.2, random_state=42)

# Function for basic manual data augmentation
def augment_data(X, y):
    """Triple a dataset by adding two circularly shifted copies of each sample.

    For every sample, three entries are emitted in this order: the untouched
    sample, a copy rolled along axis 1 (the "time shift"), and a copy rolled
    along axis 0 (the "frequency shift"). Each roll amount is drawn
    independently with ``np.random.randint(1, 5)``, i.e. from 1 to 4.
    ``np.roll`` wraps around, so values pushed past one edge reappear at the
    other. The sample's label is repeated for all three entries.

    Args:
        X: Indexable collection of array-like samples.
        y: Labels, indexable by the same positions as ``X``.

    Returns:
        A tuple ``(X_augmented, y_augmented)`` of NumPy arrays, each holding
        three entries per input sample.
    """
    samples_out, labels_out = [], []
    for idx, sample in enumerate(X):
        label = y[idx]
        # Draw order (time first, then frequency) is kept fixed so results
        # are reproducible under a given NumPy seed.
        time_rolled = np.roll(sample, shift=np.random.randint(1, 5), axis=1)
        freq_rolled = np.roll(sample, shift=np.random.randint(1, 5), axis=0)
        samples_out.extend((sample, time_rolled, freq_rolled))
        labels_out.extend((label, label, label))

    return np.array(samples_out), np.array(labels_out)

# Augment only the training split; the validation split is passed to the
# model unmodified.
X_train_augmented, y_train_augmented = augment_data(X=X_train, y=y_train)

# Define a simpler CVNN model with fewer layers and increased regularization
def create_cvnn(input_shape, output_classes):
    """Build an uncompiled small 2-D CNN classifier.

    Architecture, in order: Conv2D(16, 3x3, same padding) -> ReLU ->
    MaxPool(2x2) -> Dropout(0.2) -> Conv2D(64, 3x3, ReLU) -> MaxPool(2x2) ->
    Dropout(0.2) -> Flatten -> Dense(64, ReLU, L2 0.01) ->
    Dense(32, ReLU, L2 0.01) -> Dropout(0.4) -> Dense(output_classes, softmax).

    Only names imported at the top of this file (``tf``, ``Sequential``,
    ``Conv2D``, ...) are used, so the function does not depend on ``models``
    or ``layers`` having been imported elsewhere.

    Args:
        input_shape: Shape of a single input sample, without the batch axis.
        output_classes: Number of units in the softmax output layer.

    Returns:
        The assembled ``Sequential`` model; the caller must compile it.
    """
    l2 = tf.keras.regularizers.l2

    model = Sequential()

    # Declare the input explicitly instead of passing input_shape= to the
    # first layer, which newer Keras versions warn about.
    model.add(tf.keras.Input(shape=input_shape))

    # First convolutional stage
    model.add(Conv2D(16, (3, 3), padding="same"))
    model.add(tf.keras.layers.Activation('relu'))
    model.add(MaxPooling2D((2, 2)))
    model.add(Dropout(0.2))

    # Second convolutional stage (default "valid" padding)
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D((2, 2)))
    model.add(Dropout(0.2))

    # Flatten and fully connected layers with L2 weight regularization
    model.add(Flatten())
    model.add(Dense(64, activation="relu", kernel_regularizer=l2(0.01)))
    model.add(Dense(32, activation="relu", kernel_regularizer=l2(0.01)))
    model.add(Dropout(0.4))

    # Output layer: one probability per class
    model.add(Dense(output_classes, activation="softmax"))

    return model

# Build the network for the per-sample shape (batch axis dropped) and compile
# it for integer-encoded labels.
input_shape = X_train.shape[1:]  # expected (257, 25, 2) -- TODO confirm
cvnn_model = create_cvnn(input_shape, output_classes)
cvnn_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Stop once validation loss has not improved for 5 consecutive epochs and
# roll the weights back to the best epoch seen.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

# Fit on the augmented training data, validating on the untouched hold-out.
cvnn_model.fit(
    X_train_augmented,
    y_train_augmented,
    epochs=30,
    batch_size=10,
    validation_data=(X_val, y_val),
    callbacks=[early_stopping],
)

# Report accuracy on the same hold-out set that was used for early stopping.
loss, accuracy = cvnn_model.evaluate(X_val, y_val)
print(f"Model accuracy on validation set: {accuracy:.2f}")


Epoch 1/30
204/204 ━━━━━━━━━━━━━━━━━━━━ 15s 59ms/step - accuracy: 0.2499 - loss: 2.2168 - val_accuracy: 0.2412 - val_loss: 1.8150
Epoch 2/30
204/204 ━━━━━━━━━━━━━━━━━━━━ 13s 61ms/step - accuracy: 0.2692 - loss: 1.7704 - val_accuracy: 0.2412 - val_loss: 1.7465
Epoch 3/30
204/204 ━━━━━━━━━━━━━━━━━━━━ 11s 56ms/step - accuracy: 0.3522 - loss: 1.6956 - val_accuracy: 0.3529 - val_loss: 1.7603
Epoch 4/30
204/204 ━━━━━━━━━━━━━━━━━━━━ 15s 72ms/step - accuracy: 0.4214 - loss: 1.5881 - val_accuracy: 0.3824 - val_loss: 1.6570
Epoch 5/30
204/204 ━━━━━━━━━━━━━━━━━━━━ 13s 65ms/step - accuracy: 0.4678 - loss: 1.5316 - val_accuracy: 0.4588 - val_loss: 1.6295
Epoch 6/30
204/204 ━━━━━━━━━━━━━━━━━━━━ 12s 59ms/step - accuracy: 0.5261 - loss: 1.4092 - val_accuracy: 0.4588 - val_loss: 1.5791
Epoch 7/30
[1m2