In [None]:
#-------------------------------------------------------------------------------------JUPYTER NOTEBOOK SETTINGS-------------------------------------------------------------------------------------
# Import from the public `IPython.display` module only: pulling `display` out of
# `IPython.core.display` is deprecated. The name `display` is bound exactly once,
# to the module, which is how the rest of the notebook uses it
# (`display.display(...)`, `display.HTML(...)`); `HTML` stays available as a bare name.
import IPython.display as display
from IPython.display import HTML

# Stretch the notebook container to the full browser width.
display.display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import os
import gc
import re
import librosa
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from joblib import dump, load

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, classification_report, f1_score, recall_score, precision_score, accuracy_score

import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Layer, Input, Conv1D, MaxPooling1D, Dropout, Flatten, Dense, BatchNormalization
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import Callback, ReduceLROnPlateau, ModelCheckpoint, EarlyStopping 

import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px

In [None]:
# Gradient reversal layer used by the adversarial model
@tf.keras.utils.register_keras_serializable()
class GradientReversalLayer(Layer):
    """Layer that passes its input through unchanged and flips the gradient.

    The forward pass returns ``x`` as-is. On the backward pass the incoming
    gradient is multiplied by ``-lambda_``.

    Args:
        lambda_: Scale applied (with a negative sign) to the incoming gradient.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, lambda_=1.0, **kwargs):
        super().__init__(**kwargs)
        self.lambda_ = lambda_

    @tf.custom_gradient
    def call(self, x):
        """Return ``x`` together with the function that computes its gradient."""
        def reversed_gradient(upstream):
            # Negate and scale whatever gradient arrives from the layers after this one.
            return -self.lambda_ * upstream
        return x, reversed_gradient

    def get_config(self):
        """Return the base layer config extended with ``lambda_``."""
        return {**super().get_config(), "lambda_": self.lambda_}

In [None]:
# Locations of the artifacts this cell reads
MODEL_PATH = 'saved_data/adversarial-training_custom-cnn_final_model.keras'
HISTORY_PATH = 'saved_data/adversarial-training_custom-cnn_training_history.joblib'
TEST_DATA_PATH = 'saved_data/test_data.joblib'
GENDERS_PATH = 'saved_data/genders_data.joblib'

# Keras needs the custom layer class by name to rebuild the saved model
custom_objects = dict(GradientReversalLayer=GradientReversalLayer)
model = load_model(MODEL_PATH, custom_objects=custom_objects)

# NOTE: joblib's `load` unpickles the file contents, so only point it at trusted files.
history = load(HISTORY_PATH)

# Test split: features are converted to a float32 array, labels are kept as stored
x_test, y_test = load(TEST_DATA_PATH)
x_test = np.array(x_test, dtype=np.float32)

# Gender annotations for the train / validation / test splits
genders_train, genders_val, genders_test = load(GENDERS_PATH)

# Fit the encoder on the fixed command vocabulary so class ids can be mapped to names
all_labels = ['battery', 'description', 'environment', 'greeting', 'health', 'noise', 'nutrition', 'silence', 'sun', 'water']
label_encoder = LabelEncoder().fit(all_labels)

# Run inference; element 0 is read as the task output and element 1 as the gender output
predictions = model.predict(x_test)
y_pred_task, y_pred_gender = predictions[0], predictions[1]

# Highest-scoring class per sample, then its string label
y_pred = np.argmax(y_pred_task, axis=1)
y_pred_labels = label_encoder.inverse_transform(y_pred)

# Ground truth: encode only when the first entry is a string, otherwise use it as stored
# (assumes non-string labels are already integer class ids - TODO confirm)
y_test_encoded = label_encoder.transform(y_test) if isinstance(y_test[0], str) else y_test
y_test_labels = label_encoder.inverse_transform(y_test_encoded)

# Predicted gender: flattened scores above 0.5 become 'female', everything else 'male'
gender_labels = np.where(y_pred_gender.flatten() > 0.5, 'female', 'male')

# Reference gender: entries equal to 1 become 'female', everything else 'male'
correct_gender_labels = np.where(np.array(genders_test) == 1, 'female', 'male')

# One row per test sample: true vs. predicted command, true vs. predicted gender
result_columns = {
    'Correct Label': y_test_labels,
    'Predicted Label': y_pred_labels,
    'Correct Gender': correct_gender_labels,
    'Predicted Gender': gender_labels,
}
results_df = pd.DataFrame(result_columns)

# Render every row as an HTML table, leaving out the index column
results_html = results_df.to_html(index=False)
display.display(display.HTML(results_html))

# Command classification: accuracy plus support-weighted precision / recall / F1
weighted_avg = {'average': 'weighted'}
command_accuracy = accuracy_score(y_test_labels, y_pred_labels)
command_precision, command_recall, command_f1 = (
    score_fn(y_test_labels, y_pred_labels, **weighted_avg)
    for score_fn in (precision_score, recall_score, f1_score)
)

# Gender prediction: binary scores with 'female' as the positive class
female_positive = {'average': 'binary', 'pos_label': 'female'}
gender_accuracy = accuracy_score(correct_gender_labels, gender_labels)
gender_precision, gender_recall, gender_f1 = (
    score_fn(correct_gender_labels, gender_labels, **female_positive)
    for score_fn in (precision_score, recall_score, f1_score)
)

# Assemble the report line by line and emit it in one go
report_lines = [
    "Command Classification Metrics:",
    f"Accuracy: {command_accuracy:.4f}",
    f"Precision: {command_precision:.4f}",
    f"Recall: {command_recall:.4f}",
    f"F1 Score: {command_f1:.4f}",
    "\nGender Prediction Metrics:",
    f"Accuracy: {gender_accuracy:.4f}",
    f"Precision: {gender_precision:.4f}",
    f"Recall: {gender_recall:.4f}",
    f"F1 Score: {gender_f1:.4f}",
]
print("\n".join(report_lines))

In [None]:
# Show which keys the loaded training history contains; the plotting cell below
# indexes `history` by key (e.g. 'loss', 'val_loss', 'task_output_accuracy').
print(history.keys())

In [None]:
# Build the confusion matrix for the command-classification task from the decoded
# string labels. It was previously never defined, so this cell failed with a
# NameError on a fresh kernel. Passing `labels=` pins the row/column order to
# `label_encoder.classes_`, which is also used for the tick labels below, so the
# matrix and the axis labels stay aligned even if a class is absent from the test set.
conf_matrix = confusion_matrix(y_test_labels, y_pred_labels, labels=label_encoder.classes_)

# Plot confusion matrix with string labels for readability
fig, ax = plt.subplots(figsize=(8, 8))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Greens', ax=ax,
            xticklabels=label_encoder.classes_,
            yticklabels=label_encoder.classes_)
ax.set_xlabel('Predicted Labels')
ax.set_ylabel('True Labels')
ax.set_title('Confusion Matrix')
plt.show()

# Training curves: loss on the left panel, task accuracy on the right panel
history_fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(12, 6))

# Left panel - training vs. validation loss
loss_ax.plot(history['loss'], label='Training Loss')
loss_ax.plot(history['val_loss'], label='Validation Loss')
loss_ax.set_title('Loss Per Epoch')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

# Right panel - training vs. validation accuracy of the task output
accuracy_ax.plot(history['task_output_accuracy'], label='Training Accuracy')
accuracy_ax.plot(history['val_task_output_accuracy'], label='Validation Accuracy')
accuracy_ax.set_title('Accuracy Per Epoch')
accuracy_ax.set_xlabel('Epoch')
accuracy_ax.set_ylabel('Accuracy')
accuracy_ax.legend()

history_fig.tight_layout()
plt.show()