In [None]:
'''
Author: Steven Binder, The Photonics and Soft Robotics Lab, The University of Georgia
Date:   08-05-2025
'''

import tensorflow as tf
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os, os.path
import warnings 
import math
import librosa
import gc
import psutil

from tensorflow.keras.layers import Input, Dense,concatenate, Conv2D, Add, BatchNormalization,SpatialDropout2D, Dropout, Flatten, GlobalAveragePooling2D,MaxPooling2D
from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential, Model
from sklearn.model_selection import train_test_split
from sklearn.metrics import recall_score, precision_score, confusion_matrix,accuracy_score,f1_score,roc_curve, auc
from sklearn.preprocessing import label_binarize

# Silence warnings globally for the rest of the notebook.
warnings.filterwarnings('ignore')

# Ask TensorFlow to grow GPU memory on demand rather than reserving it all up front.
# If memory growth cannot be changed (set_memory_growth raises RuntimeError, e.g. once
# the GPUs have been initialised), the error is printed instead of aborting the cell.
# With no GPU the loop body simply never runs.
gpus = tf.config.list_physical_devices('GPU')
try:
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
except RuntimeError as err:
    print(err)

In [None]:
def reshape_features_and_labels(data,labels):
    """Split every 6-channel sample into three consecutive 2-channel samples.

    Channel pairs (0, 1), (2, 3) and (4, 5) of input sample ``s`` become output rows
    ``3*s``, ``3*s + 1`` and ``3*s + 2``. Channels beyond index 5 are ignored.

    Args:
        data: array of shape (samples, time, channels) with at least 6 channels.
        labels: one label per input sample.

    Returns:
        (reshaped_data, reshaped_labels): data of shape (3 * samples, time, 2) with the
        input dtype, and each label repeated three times so it lines up with its rows.
    """
    n_samples, n_time, _ = data.shape
    pairs_per_sample = 3
    out = np.zeros((n_samples * pairs_per_sample, n_time, 2), dtype=data.dtype)

    # Channel ch goes to row-offset ch // 2 (which pair) and slot ch % 2 (first/second).
    for ch in range(2 * pairs_per_sample):
        out[ch // 2::pairs_per_sample, :, ch % 2] = data[:, :, ch]

    return out, np.repeat(labels, pairs_per_sample)

def compute_mel_spectrogram(data, sample_rate, n_mels, n_fft, hop_length):
    """Convert every (sample, channel) time series into a log-mel spectrogram.

    Each series has NaNs replaced by zeros, is peak-normalised to [-1, 1] (unless it is
    all zeros), and is turned into a power mel spectrogram. That spectrogram is then
    converted to dB relative to its own maximum.

    Args:
        data: array of shape (samples, time, features).
        sample_rate, n_mels, n_fft, hop_length: forwarded to
            ``librosa.feature.melspectrogram``.

    Returns:
        Array of shape (samples, n_mels, n_frames, features), where n_frames is the
        number of STFT frames librosa produces. This is much smaller than ``time``:
        1 + time // hop_length with librosa's default centred framing and an even
        n_fft.
    """
    samples, n_time, features = data.shape

    # Allocated lazily from the first result. The frame count comes from librosa, not
    # from the raw series length: sizing the time axis as n_time made every assignment
    # below fail with a broadcast error.
    spectrograms = None

    for i in range(samples):
        for j in range(features):
            time_series = data[i, :, j]
            if np.any(np.isnan(time_series)):
                print(f"Warning: NaN values found in sample {i}, feature {j}")
                time_series = np.nan_to_num(time_series)
            peak = np.max(np.abs(time_series))
            if peak > 0:
                time_series = time_series / peak
            mel_spec = librosa.feature.melspectrogram(
                y=time_series,
                sr=sample_rate,
                n_mels=n_mels,
                n_fft=n_fft,
                hop_length=hop_length,
                power=2.0
            )
            mel_db = librosa.power_to_db(mel_spec, ref=np.max)
            if spectrograms is None:
                spectrograms = np.zeros((samples,) + mel_db.shape + (features,))
            spectrograms[i, :, :, j] = mel_db

    if spectrograms is None:
        # No (sample, channel) pairs to process: return an empty array with the
        # frame count librosa's centred framing would give.
        spectrograms = np.zeros((samples, n_mels, 1 + n_time // hop_length, features))

    return spectrograms

def augment_spectrogram(data,gaussian_noise_prob=0.8, gaussian_noise_std=0.1,
                       multiplicative_noise_prob=0.5, multiplicative_noise_range=(0.7, 1.3),
                       time_shift_prob=0.3, time_shift_max=20):
    """Return a randomly augmented copy of one spectrogram (the input is not modified).

    The array is indexed as ``data[:, :, c]`` per channel, and axis 1 is the axis that
    gets shifted. The result is clipped to [0, 1], so ``data`` is assumed to be
    min-max scaled already.

    Each step fires independently with its own probability, drawn from the global
    ``np.random`` state:
      * additive Gaussian noise with std ``gaussian_noise_std``;
      * a per-channel gain drawn uniformly from ``multiplicative_noise_range``;
      * a shift along axis 1 by a random integer in
        [-time_shift_max, time_shift_max], zero-filling the vacated columns.

    Returns:
        The augmented array, same shape as ``data``. It becomes float64 if Gaussian
        noise was added.
    """
    augmented_data = np.copy(data)

    if np.random.random() < gaussian_noise_prob:
        noise = np.random.normal(0, gaussian_noise_std, size=data.shape)
        augmented_data = augmented_data + noise

    if np.random.random() < multiplicative_noise_prob:
        for c in range(data.shape[2]):
            scale_factor = np.random.uniform(multiplicative_noise_range[0], multiplicative_noise_range[1])
            augmented_data[:, :, c] = augmented_data[:, :, c] * scale_factor

    if np.random.random() < time_shift_prob:
        shift_amount = np.random.randint(-time_shift_max, time_shift_max + 1)
        # Apply ONE shift to all channels at once so they stay aligned. A per-channel
        # loop that overwrote shift_amount with its absolute value shifted every
        # channel after the first in the opposite direction for negative shifts.
        if shift_amount > 0:
            augmented_data[:, shift_amount:, :] = augmented_data[:, :-shift_amount, :].copy()
            augmented_data[:, :shift_amount, :] = 0
        elif shift_amount < 0:
            k = -shift_amount
            augmented_data[:, :-k, :] = augmented_data[:, k:, :].copy()
            augmented_data[:, -k:, :] = 0

    augmented_data = np.clip(augmented_data, 0, 1)

    return augmented_data

In [None]:
# prepare train and validation data

# Training recordings D01-D10. Each recording has an X (signals) / Y (labels) file pair
# for every acquisition variant: no suffix, _100, _300, _750, _1000.
# The order matters: the loader walks this list 10 titles at a time, so each batch holds
# five consecutive X/Y pairs of one variant.
titles = [f'D{rec:02d}{axis}{suffix}.npy'
          for suffix in ('', '_100', '_300', '_750', '_1000')
          for rec in range(1, 11)
          for axis in 'XY']


# Mel-spectrogram settings passed to compute_mel_spectrogram.
sam_rate = 5000       # sample rate handed to librosa as `sr` (presumably Hz)
n_mels = 128          # mel bands -> first spatial axis of the model input
n_fft = 1024          # STFT window length
hop_length = 512      # STFT hop -> controls the number of time frames

# Load the training files 10 titles at a time (five X/Y file pairs per batch with the
# `titles` ordering), convert each batch to log-mel spectrograms and collect the results.
# Only one batch of raw time series is held in memory at a time.
batch_size = 10
all_processed_X = []  # one spectrogram array per batch, concatenated after the loop
all_processed_Y = []  # flat list of labels, one per spectrogram row

for i in range(0, len(titles), batch_size):
    batch_titles = titles[i:i+batch_size]
    k = 0  # becomes 1 once the first X file of this batch has been loaded
    total_X0, total_X1, total_X2 = None, None, None
    for file in batch_titles:
        if 'X' in file:
            # Each X file holds three .npy arrays stored back to back (read with successive
            # np.load calls). They are sliced as [sample, time, channel] below.
            with open(file, 'rb') as f:
                x0 = np.load(f)#[0:4]
                x1 = np.load(f)#[0:4] 
                x2 = np.load(f)#[0:4]
                # *_300 recordings are trimmed to time steps 40000:190000 (150000 steps).
                # NOTE(review): presumably this matches the other files' length -- the
                # vstack/concatenate calls require equal time dimensions.
                if 'X_300' in file:
                    x0 = x0[:,40000:190000,:]
                    x1 = x1[:,40000:190000,:]
                    x2 = x2[:,40000:190000,:]
                # Stack this file's samples under those already loaded in this batch.
                if k == 0:
                    total_X0 = x0
                    total_X1 = x1
                    total_X2 = x2
                    k += 1
                else:
                    total_X0 = np.vstack((total_X0, x0))
                    total_X1 = np.vstack((total_X1, x1))
                    total_X2 = np.vstack((total_X2, x2))
                del x0, x1, x2
        if 'Y' in file:
            # Matching label arrays. Only y0[0], y1[0] and y2[0] are used below, and they
            # come from the LAST Y file read in this batch.
            # NOTE(review): assumes every Y file in a batch has the same first entry per
            # array -- confirm against the data.
            with open(file, 'rb') as f:
                y0 = np.load(f)
                y1 = np.load(f)
                y2 = np.load(f)
                if 'Y_300' in file:
                    y0 = y0[40000:190000]
                    y1 = y1[40000:190000]
                    y2 = y2[40000:190000]
    # Flatten into per-sample lists: every total_X0 sample gets label y0[0], total_X1 ->
    # y1[0], total_X2 -> y2[0]. The inner loops reuse the name `i`; the outer loop is not
    # affected because its next value comes from its own range iterator.
    batch_total_X = []
    batch_total_Y = []
    if total_X0 is not None:
        for i in range(total_X0.shape[0]):
            batch_total_X.append(total_X0[i,:,:])
            batch_total_Y.append(y0[0])

        for i in range(total_X1.shape[0]):
            batch_total_X.append(total_X1[i,:,:])
            batch_total_Y.append(y1[0])

        for i in range(total_X2.shape[0]):
            batch_total_X.append(total_X2[i,:,:])
            batch_total_Y.append(y2[0])

    batch_total_X = np.array(batch_total_X)
    batch_total_Y = np.array(batch_total_Y)

    # Split each sample's channels 0-5 into three 2-channel samples (labels repeated to
    # match), then turn the time series into log-mel spectrograms.
    batch_total_X, batch_total_Y = reshape_features_and_labels(batch_total_X, batch_total_Y)

    batch_total_X = compute_mel_spectrogram(batch_total_X, sample_rate=sam_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)

    all_processed_X.append(batch_total_X)
    all_processed_Y.extend(batch_total_Y)

    # Release this batch's raw and processed arrays before loading the next one.
    del batch_total_X, batch_total_Y
    del total_X0, total_X1, total_X2
    gc.collect()

# Stack the per-batch spectrograms into one training array, then min-max scale it to
# [0, 1] using the global extremes of the training data. globMin/globMax are kept so
# other data can be mapped with the same transform.
final_Y = np.array(all_processed_Y)
final_X = np.concatenate(all_processed_X, axis=0)

globMin, globMax = final_X.min(), final_X.max()
final_X = (final_X - globMin) / (globMax - globMin)
print(final_X.shape, final_Y.shape)

In [None]:
# prepare test data
# Held-out test recording D11: one X/Y file pair per acquisition variant (10 titles, so
# the loader below handles them as a single batch).
titles = [f'D11{axis}{suffix}.npy'
          for suffix in ('', '_100', '_300', '_750', '_1000')
          for axis in 'XY']

# Same loading pipeline as for the training files, applied to the held-out D11 titles.
# Ten titles with batch_size 10 means a single batch.
batch_size = 10
all_processed_X = []  # one spectrogram array per batch
all_processed_Y = []  # flat list of labels, one per spectrogram row

for i in range(0, len(titles), batch_size):
    batch_titles = titles[i:i+batch_size]
    k = 0  # becomes 1 once the first X file of this batch has been loaded
    total_X0, total_X1, total_X2 = None, None, None
    for file in batch_titles:
        if 'X' in file:
            # Three .npy arrays stored back to back, sliced as [sample, time, channel].
            with open(file, 'rb') as f:
                x0 = np.load(f)#[0:4]
                x1 = np.load(f)#[0:4] 
                x2 = np.load(f)#[0:4]
                # *_300 recordings are trimmed to time steps 40000:190000 (150000 steps).
                if 'X_300' in file:
                    x0 = x0[:,40000:190000,:]
                    x1 = x1[:,40000:190000,:]
                    x2 = x2[:,40000:190000,:]
                # Stack this file's samples under those already loaded in this batch.
                if k == 0:
                    total_X0 = x0
                    total_X1 = x1
                    total_X2 = x2
                    k += 1
                else:
                    total_X0 = np.vstack((total_X0, x0))
                    total_X1 = np.vstack((total_X1, x1))
                    total_X2 = np.vstack((total_X2, x2))
                del x0, x1, x2
        if 'Y' in file:
            # Label arrays. Only y0[0], y1[0] and y2[0] from the LAST Y file read are used.
            # NOTE(review): assumes all Y files in the batch agree on those entries.
            with open(file, 'rb') as f:
                y0 = np.load(f)
                y1 = np.load(f)
                y2 = np.load(f)
                if 'Y_300' in file:
                    y0 = y0[40000:190000]
                    y1 = y1[40000:190000]
                    y2 = y2[40000:190000]
    # One label per sample: X0 rows -> y0[0], X1 rows -> y1[0], X2 rows -> y2[0].
    # (The inner loops reuse `i`; the outer loop's next value comes from its iterator.)
    batch_total_X = []
    batch_total_Y = []

    if total_X0 is not None:
        for i in range(total_X0.shape[0]):
            batch_total_X.append(total_X0[i,:,:])
            batch_total_Y.append(y0[0])

        for i in range(total_X1.shape[0]):
            batch_total_X.append(total_X1[i,:,:])
            batch_total_Y.append(y1[0])

        for i in range(total_X2.shape[0]):
            batch_total_X.append(total_X2[i,:,:])
            batch_total_Y.append(y2[0])

    batch_total_X = np.array(batch_total_X)
    batch_total_Y = np.array(batch_total_Y)

    # 6 channels -> three 2-channel samples, then time series -> log-mel spectrograms.
    batch_total_X, batch_total_Y = reshape_features_and_labels(batch_total_X, batch_total_Y)

    batch_total_X = compute_mel_spectrogram(batch_total_X, sample_rate=sam_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)

    all_processed_X.append(batch_total_X)
    all_processed_Y.extend(batch_total_Y)

    # Release this batch's arrays before the next iteration.
    del batch_total_X, batch_total_Y
    del total_X0, total_X1, total_X2
    gc.collect()

# Stack the held-out test spectrograms (named valid_X / valid_Y in this notebook).
valid_X = np.concatenate(all_processed_X, axis=0)
valid_Y = np.array(all_processed_Y)

# Scale with the TRAINING set's min/max (globMin / globMax from the training-data cell)
# so train and test inputs go through exactly the same mapping. Recomputing the range on
# the test set would apply a different transform and use test-set statistics. Test
# values may therefore land slightly outside [0, 1].
valid_X = (valid_X - globMin) / (globMax - globMin)
print(valid_X.shape, valid_Y.shape)

In [None]:
# dataset creation

# tf.data batch size used by every dataset built below.
batch_size = 8

# Work in float32 (Keras' default float type); this also halves memory.
final_X = final_X.astype(np.float32)
valid_X = valid_X.astype(np.float32)

# Hold out 30% of the training recordings as a validation split, keeping class proportions.
X_train, X_test, y_train, y_test = train_test_split(
    final_X,
    final_Y,
    test_size=0.3,
    stratify=final_Y,
)
print(X_train.shape, y_train.shape)

# Augment training data
# Three augmentation recipes; each yields one augmented copy of every training sample.
#   recipe 0: strong additive Gaussian noise only
#   recipe 1: mild Gaussian noise (30% chance) + per-channel gain in [0.8, 1.2]
#   recipe 2: mild Gaussian noise (30%) + per-channel gain (30%, function default range)
#             + a shift along axis 1 of up to 15 columns
aug_recipes = [
    dict(gaussian_noise_prob=1.0, gaussian_noise_std=0.12,
         multiplicative_noise_prob=0.0, time_shift_prob=0.0),
    dict(gaussian_noise_prob=0.3, gaussian_noise_std=0.05,
         multiplicative_noise_prob=1.0, multiplicative_noise_range=(0.8, 1.2),
         time_shift_prob=0.0),
    dict(gaussian_noise_prob=0.3, gaussian_noise_std=0.05,
         multiplicative_noise_prob=0.3, time_shift_prob=1.0, time_shift_max=15),
]

noise_X_train = []
for recipe in aug_recipes:
    # Same shape/dtype as X_train. augment_spectrogram works on its own copy, so the
    # original samples can be passed in directly.
    aug_samples = np.zeros_like(X_train)
    for idx in range(len(X_train)):
        aug_samples[idx] = augment_spectrogram(X_train[idx], **recipe)
    noise_X_train.append(aug_samples)

# Append the augmented copies after the originals; tile the labels in the same order.
aug_data = np.concatenate(noise_X_train, axis=0)
aug_labels = np.tile(y_train, len(aug_recipes))

X_train = np.concatenate((X_train, aug_data), axis=0)
y_train = np.concatenate((y_train, aug_labels), axis=0)

del noise_X_train, aug_data, aug_labels
gc.collect()

print(X_train.shape, y_train.shape)

def _batched_dataset(features, labels, shuffle_buffer=None):
    """Wrap in-memory (features, labels) as a cached, optionally shuffled, batched, prefetched tf.data pipeline."""
    ds = tf.data.Dataset.from_tensor_slices((features, labels)).cache()
    if shuffle_buffer is not None:
        ds = ds.shuffle(shuffle_buffer)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)


# Only the training data is shuffled. The validation split (X_test / y_test) and the
# held-out test set (valid_X / valid_Y) keep their order, so predictions line up with
# y_test / valid_Y later on.
train_DS = _batched_dataset(X_train, y_train, shuffle_buffer=1000)
valid_DS = _batched_dataset(X_test, y_test)
test_DS = _batched_dataset(valid_X, valid_Y)

# Drop the NumPy names for the large splits (the datasets hold what they need).
del X_train, X_test
gc.collect()

process = psutil.Process()
print(f"Memory usage: {process.memory_info().rss / 1024 / 1024:.2f} MB")

In [None]:
# model parameters
drop = 0.2          # dropout / spatial-dropout rate used throughout the network
reg = 0.0001        # L2 factor on convolution kernels
# (n_mels, frames, channels). 128 = n_mels; 293 is presumably 1 + 150000 // 512 frames for
# the 150000-step windows -- keep in sync with the spectrogram settings.
input_shape = (128, 293, 2)
kernel_shape = 3    # square conv kernel size in model1 / res_block
layers = [32 * 2 ** k for k in range(5)]   # 32, 64, 128, 256, 512

def model1(input_shape):
    """Plain CNN branch: three conv stages followed by Flatten.

    Each stage is two kernel_shape x kernel_shape ReLU convolutions with L2(reg),
    BatchNorm, 2x2 max-pooling and SpatialDropout2D(drop). Stage widths are
    layers[0], layers[1] and layers[2].

    Args:
        input_shape: per-sample input shape, e.g. (n_mels, frames, channels).

    Returns:
        An uncompiled ``Sequential`` model producing a flat feature vector.
    """
    model = Sequential()
    for stage, filters in enumerate(layers[:3]):
        # Only the very first layer of a Sequential model uses input_shape. Passing it
        # to later layers is ignored, misleading, and triggers warnings in Keras 3.
        first_layer_kwargs = {'input_shape': input_shape} if stage == 0 else {}
        model.add(Conv2D(filters, (kernel_shape, kernel_shape), activation='relu', padding='same',
                         kernel_regularizer=l2(reg), **first_layer_kwargs))
        model.add(Conv2D(filters, (kernel_shape, kernel_shape), activation='relu', padding='same',
                         kernel_regularizer=l2(reg)))
        model.add(BatchNormalization())
        model.add(MaxPooling2D((2, 2)))
        model.add(SpatialDropout2D(drop))

    model.add(Flatten())
    return model

def res_block(input_ten,filters,strides=(1,1)):
    """Residual block: two kernel_shape x kernel_shape convolutions plus a shortcut.

    Args:
        input_ten: Keras tensor to transform.
        filters: number of output channels.
        strides: stride of the first convolution and of the projection shortcut.
            (2, 2) halves the spatial size. An int n is treated as (n, n).

    Returns:
        Keras tensor with ``filters`` channels after the final ReLU.
    """
    if isinstance(strides, int):
        strides = (strides, strides)
    strides = tuple(strides)

    # Main path. `strides` has to be applied here; if it only toggles the projection,
    # a "downsampling" block never actually downsamples.
    x = Conv2D(filters, (kernel_shape, kernel_shape), strides=strides, padding='same',
               kernel_regularizer=l2(reg))(input_ten)
    x = BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = Conv2D(filters, (kernel_shape, kernel_shape), padding='same', kernel_regularizer=l2(reg))(x)
    x = BatchNormalization()(x)

    # Shortcut: identity when spatial size and channel count are unchanged. Otherwise a
    # strided 1x1 projection matches the main path so Add() is valid.
    shortcut = input_ten
    if strides != (1, 1) or input_ten.shape[-1] != filters:
        shortcut = Conv2D(filters, (1, 1), strides=strides, padding='same',
                          kernel_regularizer=l2(reg))(input_ten)
        shortcut = BatchNormalization()(shortcut)

    x = Add()([x, shortcut])
    x = tf.keras.layers.Activation('relu')(x)
    return x

def model2(input_shape):
    """Residual branch: a 7x7 stride-2 stem, four stages of two res_blocks, then global pooling.

    Stage widths are layers[0..3]. Every stage after the first starts with a
    strides=(2, 2) block.

    Args:
        input_shape: per-sample input shape.

    Returns:
        A functional ``Model`` mapping the input to a layers[3]-wide pooled feature vector.
    """
    inputs = Input(shape=input_shape)

    # Stem: 7x7 stride-2 conv + BN + ReLU, then 3x3 stride-2 max-pool.
    x = Conv2D(layers[0], (7, 7), strides=(2, 2), padding='same', kernel_regularizer=l2(reg))(inputs)
    x = BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    for stage, filters in enumerate(layers[:4]):
        entry_strides = (1, 1) if stage == 0 else (2, 2)
        x = res_block(x, filters, strides=entry_strides)
        x = res_block(x, filters)

    pooled = GlobalAveragePooling2D()(x)
    return Model(inputs, pooled)

def combo_model(input_shape):
    """Two-branch classifier built from model1 (plain CNN) and model2 (residual CNN).

    Both branches see the same input. Their features are concatenated and passed
    through a ReLU Dense funnel of widths layers[4] down to layers[0], each followed by
    Dropout(drop). A 3-way softmax comes last.

    Args:
        input_shape: per-sample input shape.

    Returns:
        An uncompiled functional ``Model``.
    """
    signal = Input(shape=input_shape)

    cnn_features = model1(input_shape)(signal)
    resnet_features = model2(input_shape)(signal)

    x = concatenate([cnn_features, resnet_features])
    for units in reversed(layers):
        x = Dense(units, activation='relu')(x)
        x = Dropout(drop)(x)

    probabilities = Dense(3, activation='softmax')(x)
    return Model(signal, probabilities)

model = combo_model(input_shape=input_shape)  # fresh, uncompiled two-branch classifier

def lr_schedule(epoch):
    """Step-decay learning rate: start at 1e-5 and halve every 20 epochs.

    ``epoch`` is the 0-based index Keras passes to LearningRateScheduler. Because of the
    ``+ 1``, the first halving applies from epoch index 19.
    """
    base_lr = 0.00001
    halvings = math.floor((epoch + 1) / 20.0)
    return base_lr * (0.5 ** halvings)

# Callbacks: step-decay learning rate, and a checkpoint written whenever val_accuracy improves.
lr_scheduler = LearningRateScheduler(lr_schedule)
checkpoint = ModelCheckpoint(
    'best_model.keras',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
)

# NOTE(review): with use_ema=True the optimizer keeps an exponential moving average of the
# weights -- confirm whether the checkpointed weights are the EMA or raw ones you intend.
optimizer = Adam(learning_rate=0.00001, clipnorm=1.0, use_ema=True)
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

r = model.fit(
    train_DS,
    validation_data=valid_DS,
    epochs=150,
    callbacks=[lr_scheduler, checkpoint],
)

# Training vs. validation curves: loss on the left, accuracy on the right.
sns.set(rc={'figure.figsize':(20,5)})
fig, ax = plt.subplots(1, 2)
epochs_run = range(len(r.history['loss']))

for panel, (metric, title) in enumerate([('loss', 'Loss'), ('accuracy', 'Accuracy')]):
    sns.lineplot(x=epochs_run, y=r.history[metric], label=metric, ax=ax[panel])
    sns.lineplot(x=epochs_run, y=r.history[f'val_{metric}'], label=f'val_{metric}', ax=ax[panel])
    ax[panel].set_title(title)
    # Label both panels (previously only the loss panel had an x label).
    ax[panel].set_xlabel('Epochs')
    ax[panel].set_ylabel(title)

plt.show()

In [None]:
# Validation data metrics and confusion matrix
# Reload the checkpoint with the best val_accuracy before scoring.
model = tf.keras.models.load_model('best_model.keras')

# valid_DS is not shuffled, so prediction order matches y_test.
y_pred = model.predict(valid_DS)
y_pred_c = y_pred.argmax(axis=1)

accuracy = accuracy_score(y_test, y_pred_c)
recall_macro, precision_macro, f1DS = (
    score_fn(y_test, y_pred_c, average='macro')
    for score_fn in (recall_score, precision_score, f1_score)
)

for line in (f"Validation Set Accuracy: {accuracy:.4f}",
             f"Recall: {recall_macro:.4f}",
             f"Precision: {precision_macro:.4f}",
             f"F1 Score: {f1DS:.4f}"):
    print(line)

n_outputs = model.output_shape[-1]
target_names = [f'Class {c}' for c in range(n_outputs)]
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_matrix(y_test, y_pred_c), annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Validation Set Confusion Matrix')
plt.show()

In [None]:
# Test data metrics and confusion matrix
# Held-out test set (D11 recordings). test_DS was built from (valid_X, valid_Y) without
# shuffling, so predictions must be scored against valid_Y. y_test holds the labels of
# the validation split and does not correspond to these predictions.
y_pred = model.predict(test_DS)
y_pred_c = np.argmax(y_pred, axis=1)

accuracy = accuracy_score(valid_Y, y_pred_c)
recall_macro = recall_score(valid_Y, y_pred_c, average='macro')
precision_macro = precision_score(valid_Y, y_pred_c, average='macro')
f1DS = f1_score(valid_Y, y_pred_c, average='macro')

print(f"Test Set Accuracy: {accuracy:.4f}")
print(f"Recall: {recall_macro:.4f}")
print(f"Precision: {precision_macro:.4f}")
print(f"F1 Score: {f1DS:.4f}")

target_names = [f'Class {i}' for i in range(model.output_shape[-1])]
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_matrix(valid_Y, y_pred_c), annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Test Set Confusion Matrix')
plt.show()

In [None]:
# AUC for validation set

# One-vs-rest ROC curves and AUC per class for the validation split (X_test / y_test),
# plus a macro average over classes.
y_pred = model.predict(valid_DS)

n_classes = model.output_shape[-1]
binary_y_test = label_binarize(y_test, classes=range(n_classes))

tpr = {}
fpr = {}
roc_auc = {}
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(binary_y_test[:, i], y_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Macro average: interpolate each class's ROC onto the union of all FPR points and
# average the TPRs.
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

mean_tpr /= n_classes

fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

print(f"Macro-average AUC: {roc_auc['macro']:.4f}")

# Draw on a fresh figure so this plot never shares axes with an earlier one.
plt.figure()
for i in range(n_classes):
    plt.plot(fpr[i], tpr[i], lw=2,label=f'Class {i} (AUC = {roc_auc[i]:.4f})')

plt.plot(fpr["macro"], tpr["macro"], color='navy', linestyle='--', lw=2,label=f'Macro-average (AUC = {roc_auc["macro"]:.4f})')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate (FPR)')
plt.ylabel('True Positive Rate (TPR)')
plt.title('Validation Set ROC Curves')
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)
plt.show()


In [None]:
# AUC for test set

# One-vs-rest ROC curves and AUC per class for the held-out test set (valid_X / valid_Y),
# plus a macro average over classes. test_DS is unshuffled, so rows line up with valid_Y.
y_pred = model.predict(test_DS)

n_classes = model.output_shape[-1]
binary_y_test = label_binarize(valid_Y, classes=range(n_classes))

tpr = {}
fpr = {}
roc_auc = {}
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(binary_y_test[:, i], y_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Macro average: interpolate each class's ROC onto the union of all FPR points and
# average the TPRs.
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

mean_tpr /= n_classes

fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

print(f"Macro-average AUC: {roc_auc['macro']:.4f}")

# Draw on a fresh figure so this plot never shares axes with an earlier one.
plt.figure()
for i in range(n_classes):
    plt.plot(fpr[i], tpr[i], lw=2,label=f'Class {i} (AUC = {roc_auc[i]:.4f})')

plt.plot(fpr["macro"], tpr["macro"], color='navy', linestyle='--', lw=2,label=f'Macro-average (AUC = {roc_auc["macro"]:.4f})')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate (FPR)')
plt.ylabel('True Positive Rate (TPR)')
plt.title('Test Set ROC Curves')
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)
plt.show()
