In [None]:
# ----------------------------- JUPYTER NOTEBOOK SETTINGS -----------------------------
# Widen the notebook's `.container` element to the full browser width via injected CSS.
# NOTE(review): `.container` is a classic-Notebook class; confirm it has an effect in the UI in use.
# `display`/`HTML` are imported from the public `IPython.display` module; importing
# `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import os
import gc
import re
import librosa
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from joblib import dump, load

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder

import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Input, Reshape, UpSampling2D, Conv2D, GlobalAveragePooling2D, Dropout, Dense
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import Callback, ReduceLROnPlateau, ModelCheckpoint, EarlyStopping 
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras import mixed_precision

import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px

In [None]:
# Each joblib file is expected to unpack into a (features, labels) pair for one split.
# joblib.load is pickle-based: only load files produced by a trusted source.
(x_train, y_train), (x_val, y_val), (x_test, y_test) = (
    load('saved_data/train_data.joblib'),
    load('saved_data/val_data.joblib'),
    load('saved_data/test_data.joblib'),
)
print("All data has been loaded properly!")

In [None]:
# Report shape and dtype of every split, then whether any split contains NaN or infinite values.
_named_splits = [("train", x_train), ("validation", x_val), ("test", x_test)]

for split_name, split_features in _named_splits:
    print(f"{split_name.capitalize()} data:", split_features.shape, split_features.dtype)

# Check for any NaN or inf values in your dataset (all NaN checks first, then all inf checks)
for check_name, check_fn in (("NaNs", np.isnan), ("Infs", np.isinf)):
    for split_name, split_features in _named_splits:
        print(f"{check_name} in {split_name}:", check_fn(split_features).any())

In [None]:
# ONEHOT ENCODING THE LABELS
label_encoder = LabelEncoder()
y_train_encoded = label_encoder.fit_transform(y_train)
y_val_encoded = label_encoder.transform(y_val)
y_test_encoded = label_encoder.transform(y_test)

# Convert labels to one-hot encoding
y_train_onehot = to_categorical(y_train_encoded)
y_val_onehot = to_categorical(y_val_encoded)
y_test_onehot = to_categorical(y_test_encoded)

In [None]:
def load_latest_weights(weights_dir, file_pattern):
    """Return the path of the most recently modified weights file in ``weights_dir``.

    Despite its name, this function only *locates* the file; the caller is
    responsible for passing the returned path to ``model.load_weights``.

    Args:
        weights_dir: Directory to search (non-recursively).
        file_pattern: Substring a file name must contain to be considered,
            e.g. ``'.weights.h5'``.

    Returns:
        ``os.path.join(weights_dir, name)`` of the newest matching regular file
        (by modification time), or ``None`` when the directory does not exist or
        holds no matching file.
    """
    # A missing directory simply means "nothing saved yet", not an error.
    if not os.path.isdir(weights_dir):
        print("No weights file found.")
        return None

    matching_paths = (
        os.path.join(weights_dir, name)
        for name in os.listdir(weights_dir)
        if file_pattern in name
    )
    # Skip sub-directories whose names happen to contain the pattern.
    candidates = [path for path in matching_paths if os.path.isfile(path)]

    latest_weights = max(candidates, key=os.path.getmtime, default=None)
    if latest_weights is None:
        print("No weights file found.")
        return None

    print(f"Loading weights from {latest_weights}")
    return latest_weights

class SaveWeightsCallback(Callback):
    """Keras callback that writes the model weights every ``save_freq`` epochs.

    Args:
        save_freq: Save interval in epochs. Weights are written after epochs
            ``save_freq``, ``2 * save_freq``, ... (1-based epoch numbers).
        filepath: Target path template; ``{epoch}`` is replaced with the
            1-based epoch number, e.g. ``'weights_epoch_{epoch}.weights.h5'``.

    Raises:
        ValueError: If ``save_freq`` is smaller than 1.
    """

    def __init__(self, save_freq, filepath):
        super().__init__()
        # save_freq == 0 would only blow up (ZeroDivisionError) at the end of the first
        # epoch, mid-training; reject non-positive intervals up front instead.
        if save_freq < 1:
            raise ValueError(f"save_freq must be a positive integer, got {save_freq!r}")
        self.save_freq = save_freq
        self.filepath = filepath

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is a 0-based index, so epoch + 1 is the number of completed epochs.
        completed_epochs = epoch + 1
        if completed_epochs % self.save_freq == 0:
            self.model.save_weights(self.filepath.format(epoch=completed_epochs))

In [None]:
# Set up mixed precision policy
mixed_precision.set_global_policy('mixed_float16')

In [None]:
# Model: each per-sample feature matrix is lifted to a 3-channel 96x96 "image" and fed to an
# ImageNet-pretrained MobileNetV2 whose weights are frozen; only the adapter conv and the
# dense head are trained.
INPUT_SHAPE = tuple(x_train.shape[1:])  # per-sample shape; (13, 332) in the original notebook

# Load MobileNetV2 pre-trained on ImageNet
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(96, 96, 3))

# Freeze the convolutional base
base_model.trainable = False

# Define the model
model = Sequential([
    Input(shape=INPUT_SHAPE),  # Explicit Input layer
    Reshape((*INPUT_SHAPE, 1)),  # Reshape to add a channel dimension
    UpSampling2D(size=(8, 1)),  # Upsample to increase height
    Conv2D(3, (3, 3), activation='relu', padding='same'),  # Convert to 3 channels
    tf.keras.layers.Resizing(96, 96),  # Resize to the expected input shape of MobileNetV2
    base_model,  # Add the pre-trained MobileNetV2
    GlobalAveragePooling2D(),  # Pooling layer to reduce spatial dimensions
    Dropout(0.2),  # Dropout for regularization
    Dense(64, activation='relu'),  # Fully connected layer
    Dropout(0.5),  # Another dropout for regularization
    # Output layer. Under the global 'mixed_float16' policy layers compute in float16;
    # keep the softmax in float32 so the class probabilities and the loss stay numerically stable.
    Dense(y_train_onehot.shape[1], activation='softmax', dtype='float32'),
])

# Configure the optimizer, loss, and metrics
optimizer = Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

# Setup callbacks
# NOTE(review): Keras callbacks reset their counters when each fit() call starts, so
# EarlyStopping's patience/best value (and restore_best_weights) apply per stage, not across the run.
early_stopping_monitor = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1)
reduce_lr_on_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.000001, verbose=1)
weights_saver = SaveWeightsCallback(save_freq=50, filepath='saved_data/mobilenetv2_finetuned_weights_epoch_{epoch}.weights.h5')


# Segment-based training setup: train in stages of num_epochs_per_stage epochs, keeping a
# global epoch counter (passed as initial_epoch) so epoch numbering continues across stages.
num_epochs_per_stage = 50
total_epochs = 500
current_epoch = 0
all_history = []

while current_epoch < total_epochs:
    try:
        # Load the latest model weights if available. On the first stage this can resume from
        # weights left in saved_data/ by an earlier run; later stages reload the file just
        # written by weights_saver.
        try:
            latest_weights_file = load_latest_weights('saved_data', '.weights.h5')
            if latest_weights_file:
                model.load_weights(latest_weights_file)
        except Exception as e:
            print("Loading weights failed:", e)

        # Never train past total_epochs, even when it is not a multiple of the stage length.
        stage_end_epoch = min(current_epoch + num_epochs_per_stage, total_epochs)

        # Train the model for a stage
        history = model.fit(
            x_train,
            y_train_onehot,
            epochs=stage_end_epoch,
            batch_size=512,
            validation_data=(x_val, y_val_onehot),
            callbacks=[weights_saver, early_stopping_monitor, reduce_lr_on_plateau],
            initial_epoch=current_epoch,
            verbose=1
        )

        # Append segment history to the total history
        all_history.append(history.history)

        # Advance by the number of epochs actually run (fewer if early stopping fired)
        current_epoch += len(history.history['loss'])

        # Optionally perform garbage collection
        gc.collect()

        # Check if early stopping was triggered
        if early_stopping_monitor.stopped_epoch > 0:
            print(f"Early stopping triggered at epoch {current_epoch}")
            break

    except Exception as e:
        print("An error occurred during training:", e)
        break

In [None]:
# Save final model
model.save('saved_data/mobilenetv2_finetuned_model.keras')

# Concatenate all history segments into one dictionary if there are any segments
if all_history:
    final_history = {key: np.concatenate([seg[key] for seg in all_history]) for key in all_history[0]}
    # Save the final training history
    dump(history.history, 'saved_data/mobilenetv2_finetuned_training_history.joblib')
else:
    print("No training history was recorded.")