<a href="https://drive.google.com/file/d/1-lmvLqHRoVztabnwQ8RbZuDhpsd1kmYY/view?usp=sharing" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TIMBRE

**Research project** for the course of *Selected Topics in Music and Acoustic Engineering* :

***Music Instrument Classification***

This project addresses the development of a system for automatic music instrument classification. The dataset
provided is the MedleyDB collection, which contains 196 professionally recorded multitrack recordings, including
individual stems corresponding to isolated instruments.
 
Students are tasked with designing a classification pipeline that either recognizes instruments in multitimbral
mixtures or classifies individual stems where typically one instrument is active. The project encourages a flexible
approach, allowing exploration of both isolated and polyphonic scenarios.

Key aspects to investigate include:
- Analyzing the robustness of instrument recognition systems when facing different levels of overlapping instruments within a mixture.
- Studying the relationship between instrumentation and musical genre, as genre annotations are also available in the dataset.
- Exploring the use of co-occurrence matrices to model and understand typical combinations of instruments within different musical contexts.

The students should experiment with feature extraction techniques sensitive to timbral characteristics, such as
spectral descriptors and MFCCs, and assess the effectiveness of classification.

### Team:
* Andrea Crisafulli
* Marco Porcella
* Giacomo De Toni
* Gianluigi Vecchini

## *Import libraries*:

In [1]:
# === Core Python & Scientific Computing ===
import numpy as np                # Numerical computing
import pandas as pd              # Data handling and manipulation
import matplotlib.pyplot as plt  # Plotting
from pathlib import Path         # File path handling
import scipy.signal as signal    # Signal processing tools

# === Audio Processing ===
import librosa                   # Audio analysis
import librosa.display           # Visualization for librosa outputs
import IPython.display as ipd    # For audio playback in notebooks

# === Scikit-learn: ML & Preprocessing ===
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, ConfusionMatrixDisplay  # Evaluation
from sklearn.decomposition import PCA         # Dimensionality reduction
from sklearn.preprocessing import scale, StandardScaler, MultiLabelBinarizer  # Data scaling & encoding
from sklearn.model_selection import train_test_split  # Dataset splitting
from sklearn.svm import SVC                    # Support Vector Classifier
from sklearn.neighbors import KNeighborsClassifier  # k-NN classifier
from sklearn.cluster import KMeans            # Clustering

# === Deep Learning: TensorFlow / Keras ===
import tensorflow as tf
from tensorflow import keras
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint  # Training utilities
from keras.optimizers import Adam            # Optimizer for model training
from keras import layers, models

# === Optional: PyTorch (if used) ===
#import torch
#import torch.nn as nn
#import torch.nn.functional as F              # Functional API for building models

# === Others ===
import yaml                                   # Parsing metadata in YAML format
from collections import Counter               # Frequency counting for label analysis
from tqdm import tqdm                         # Progress bar for loops

# === Plotting Style ===
#plt.style.use("seaborn-v0_8")                 # Set default plotting style

EXECUTION SETTINGS

In [2]:
# Variables for data extraction: which file types the import cell collects
considerMixFiles = True    # one entry per song mix (<song>_MIX.wav)
considerStemFiles = True   # one entry per stem listed in the song metadata
considerRawFiles = False   # one entry per raw track listed under each stem

# Variable for OS recognition (selects the dataset base path in the import cell)
OSys = 0  # 0 = WINDOWS, 1 = MACOS

In [3]:
# Since 83 different labels are present we can group together similar labels.
# Maps each instrument name found in the metadata to the (coarser) class used
# for training.
# NOTE(review): the trailing numbers look like per-label occurrence counts in
# the dataset — TODO confirm.
labelGroupsDict = {
    'Main System': 'Main System',               # 12

    # Vocals
    'male singer': 'vocals',                    # 82
    'female singer': 'vocals',                  # 57
    'male rapper': 'vocals',                    # 8
    'male speaker': 'vocals',                   # 2
    'vocalists': 'vocals',                      # 60

    # Guitar
    'acoustic guitar': 'guitar',                # 50
    'clean electric guitar': 'guitar',          # 94
    'distorted electric guitar': 'guitar',      # 54
    'lap steel guitar': 'guitar',

    # Plucked strings similar to guitar
    'liuqin': 'similar guitars',
    'banjo': 'similar guitars',
    'mandolin': 'similar guitars',              # 18
    'oud':'similar guitars',
    'zhongruan':'similar guitars',

    # Electric bass
    'electric bass': 'electric bass',           # 126

    # Violin
    'violin': 'violin',                         # 41
    'violin section': 'violin',                 # 28

    # Viola
    'viola': 'viola',                           # 15
    'viola section': 'viola',                   # 8

    # Cello
    'cello': 'cello',                           # 24
    'cello section': 'cello',

    # Double bass (grouped under the generic 'strings' class)
    'double bass': 'strings',                   # 33

    # String section and erhu (generic 'strings' class)
    'string section': 'strings',                # 12
    'erhu': 'strings',                          # 12

    # Trumpet (mapped to the generic 'brass' class, unlike trombone / french horn / tuba)
    'trumpet': 'brass',                         # 15
    'trumpet section': 'brass',

    # Trombone
    'trombone': 'trombone',
    'trombone section': 'trombone',

    # French horn
    'french horn': 'french horn',
    'french horn section': 'french horn',

    # Tuba
    'tuba': 'tuba',

    # Brass
    'horn section': 'brass',
    'brass section': 'brass',                   # 16

    # Saxophone
    'saxophone': 'saxophone',
    'soprano saxophone': 'saxophone',
    'alto saxophone': 'saxophone',
    'tenor saxophone': 'saxophone',
    'baritone saxophone': 'saxophone',

    # Woodwinds
    'dizi': 'woodwinds',
    'flute': 'woodwinds',                       # 22
    'flute section': 'woodwinds',
    'piccolo': 'woodwinds',
    'clarinet': 'woodwinds',                    # 17
    'clarinet section': 'woodwinds',
    'bass clarinet': 'woodwinds',
    'oboe': 'woodwinds',
    'bassoon': 'woodwinds',                     # 11
    'bamboo flute' : 'woodwinds',

    # Drum set
    'drum set': 'drum set',                     # 131
    'snare drum': 'drum set',
    'kick drum': 'drum set',
    'bass drum': 'drum set',

    # Percussion
    'toms': 'percussion',
    'doumbek': 'percussion',
    'tabla': 'percussion',                      # 27
    'darbuka': 'percussion',
    'tambourine': 'percussion',
    'shaker': 'percussion',
    'bongo': 'percussion',
    'cymbal': 'percussion',                     # 14
    'timpani': 'percussion',                    # 13
    'auxiliary percussion': 'percussion',       # 36
    'claps': 'percussion',
    'gong': 'percussion',
    'gu':'percussion',
    'drum machine': 'percussion',                 # 32
    'scratches': 'percussion',

    # Pitched mallet percussion (the class key is the Italian word 'xilofono')
    'chimes': 'xilofono',
    'glockenspiel':'xilofono',
    'vibraphone':'xilofono',                    # 16

    # Harps
    'guzheng':'harps',
    'harp':'harps',
    'yangqin':'harps',                          # 12

    # Piano
    'piano': 'piano',                           # 86
    'electric piano': 'piano',

    # Accordion
    'accordion': 'accordion',                   # 10

    # Harmonica
    'harmonica': 'harmonica',                   # 4

    # Keyboards (note: 'tack piano' is grouped here, not under 'piano')
    'tack piano': 'keyboard',                   # 18
    'melodica': 'keyboard',
    'sampler': 'keyboard',
    'synthesizer': 'keyboard',                  # 74

    # fx/processed sound
    'fx/processed sound':'fx/processed sound',  # 63
}

### *Import audio data*:

In [4]:
windows = 0
macOS = 1

# Dataset root for each supported OS. An unsupported OSys value yields None
# (instead of leaving basePath undefined), so the error branch below is
# actually reachable rather than failing earlier with a NameError.
basePath = {
    windows: Path("E:/MedleyDB"),                        # For windows
    macOS: Path("/Volumes/Extreme SSD/MedleyDB"),        # For mac
}.get(OSys)
errorPath = basePath is None

data = []

if errorPath:
    print("ERROR: defined OS is not supported.")
elif not (considerMixFiles or considerStemFiles or considerRawFiles):
    print("ERROR: no data to extract with current setup.")
else:
    audioPath = basePath / "Audio"
    if not audioPath.is_dir():
        raise FileNotFoundError(f"MedleyDB audio folder not found: {audioPath}")

    # Iterates over the song directories of the MedleyDB/Audio folder. Sorted so
    # the row order of the DataFrame does not depend on the file system.
    for songDir in sorted(audioPath.iterdir()):
        # Security check to skip not directory items
        if not songDir.is_dir():
            continue

        songName = songDir.name
        yamlFilePath = songDir / f"{songName}_METADATA.yaml"  # Path to YAML metadata file

        # A folder without metadata cannot be labelled: report it and move on
        if not yamlFilePath.is_file():
            print(f"WARNING: skipping {songName}: metadata file not found.")
            continue

        # Opens YAML metadata file in read mode (an empty file loads as None)
        with open(yamlFilePath, "r") as f:
            metadata = yaml.safe_load(f) or {}

        # Recovers stems from metadata and stores in dictionary
        stemsData = metadata.get("stems") or {}
        labelArray = []  # grouped labels of every stem, used for the mix entry

        # Iterates over stems
        for stemId, stem in stemsData.items():
            instrument = stem.get("instrument")
            try:
                label = labelGroupsDict[instrument]
            except (KeyError, TypeError) as err:
                # TypeError covers an unhashable value (e.g. a list of instruments)
                raise KeyError(
                    f"Instrument {instrument!r} (song {songName}, stem {stemId}) "
                    "has no entry in labelGroupsDict"
                ) from err

            # Raw files: one entry per raw track of the stem
            if considerRawFiles:
                for raw in (stem.get("raw") or {}).values():
                    rawPath = songDir / f"{songName}_RAW" / raw.get("filename")

                    # Checks for valid files (hidden entries start with a dot)
                    if not rawPath.name.startswith("."):
                        data.append({
                            "song": songName,
                            "songPath": songDir,
                            "label": label,
                            "filePath": rawPath
                        })

            # Stem files: one entry per stem
            if considerStemFiles:
                data.append({
                    "song": songName,
                    "songPath": songDir,
                    "label": label,
                    "filePath": songDir / f"{songName}_STEMS" / stem.get("filename")
                })

            if considerMixFiles:
                labelArray.append(label)

        # Mix file: labelled with the sorted, de-duplicated labels of its stems
        if considerMixFiles:
            data.append({
                "song": songName,
                "songPath": songDir,
                "label": "|".join(sorted(set(labelArray))),
                "filePath": songDir / f"{songName}_MIX.wav"
            })

    # Create DataFrame
    df = pd.DataFrame(data)
    print(f"Loaded {len(df)} audio files.")

    # String conversion to list ("a|b" -> ["a", "b"])
    df["labelList"] = df["label"].str.split("|")

    # Multi-hot encoding of the label lists over the whole dataset
    mlbAllDataset = MultiLabelBinarizer()
    audioLabelsBinary = mlbAllDataset.fit_transform(df["labelList"])
    audioLabelsBinary = np.asarray(audioLabelsBinary)

FileNotFoundError: [WinError 3] Impossibile trovare il percorso specificato: 'E:\\MedleyDB\\Audio'

In [None]:
# Head of dataset: first rows of the file table built by the import cell
df.head()

In [None]:
# Tail of dataset: last rows of the file table built by the import cell
df.tail()

In [None]:
# Info of dataset: column names, dtypes and non-null counts
df.info()

In [None]:
# Show all rows and columns without truncation
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)


# One column per class of the full-dataset binarizer; the column sums give the
# number of files carrying each label, listed from most to least frequent
df_labels = pd.DataFrame(audioLabelsBinary, columns=mlbAllDataset.classes_)
df_labels.sum().sort_values(ascending=False)


In [None]:
# File paths and label lists, taken column-wise from the DataFrame
audioFiles = df["filePath"].tolist()
audioLabels = df["labelList"].tolist()

# Sanity check: both lists must hold one entry per row
if len(audioFiles) != len(audioLabels):
    print("Error in dataset")
else:
    print(f"Extracted files and labels for a total lenght of {len(audioFiles)}")

In [None]:
# Select number of classes to extract to form the partial dataset (MAX = 82)
# NOTE(review): the "MAX = 82" figure seems to refer to the ungrouped labels;
# after grouping through labelGroupsDict the upper bound is presumably the
# number of distinct labels in df["labelList"] — confirm.
n_classes = 5

In [None]:
# Flatten all label lists and count how many files carry each label
all_labels = [label for labels in df["labelList"] for label in labels]
label_counts = Counter(all_labels)

# Number of distinct classes. (Counter.total() would be the sum of all
# occurrences, which is not the bound n_classes has to respect.)
n_available = len(label_counts)

if 1 <= n_classes <= n_available:
    # Take the n_classes most frequent classes
    top_labels = [label for label, _ in label_counts.most_common(n_classes)]
    print("Top labels:", top_labels)

    # Keep the rows where at least one label belongs to top_labels
    top_label_set = set(top_labels)
    df_subset = df[df["labelList"].apply(lambda labels: not top_label_set.isdisjoint(labels))]

    # Extract audio paths and labels
    audioFilesSubset = df_subset["filePath"].tolist()
    audioLabelsSubset = df_subset["labelList"].tolist()

    # Multi-hot encoding restricted to the selected classes, in frequency order
    mlbPartialDataset = MultiLabelBinarizer(classes=top_labels)
    audioLabelsSubsetBinary = mlbPartialDataset.fit_transform(df_subset["labelList"])
    audioLabelsSubsetBinary = np.asarray(audioLabelsSubsetBinary)

    if len(audioFilesSubset) == len(audioLabelsSubset):
        print(f"Extracted {len(audioFilesSubset)} samples from top {n_classes} labels")
    else:
        print("Mismatch in extracted data")

    print(f"Subset binary label matrix shape: {audioLabelsSubsetBinary.shape}")
else:
    print(f"ERROR: selected n_classes of {n_classes} must be between 1 and the total number of classes in DF ({n_available})")

In [None]:
# Dataset used by the loading cell: compared there against the two values below
executionMode = 1   # 0 = all dataset, 1 = partial dataset

# Named values for executionMode
allDataset = 0
partialDataset = 1

In [None]:
def _extractExcerpt(y, num_samples, minAmplitude):
    """Return a peak-normalized, fixed-length float16 excerpt of ``y``.

    The excerpt starts at the first sample whose magnitude exceeds
    ``minAmplitude`` (after normalization). If fewer than ``num_samples``
    samples remain after that point, the window is moved back to the last
    ``num_samples`` samples; a signal shorter than ``num_samples`` is cut at
    the onset and zero-padded at the end.
    """
    # Peak normalization (skipped for empty or all-zero signals)
    peak = np.max(np.abs(y)) if y.size else 0.0
    if peak > 0:
        y = y / peak

    # First significant index. The magnitude is used so that a negative
    # excursion counts as an onset too, consistently with the normalization.
    loud = np.abs(y) > minAmplitude
    start_index = int(np.argmax(loud)) if loud.any() else 0

    # Not enough samples after the onset: shift the window back if possible
    if start_index + num_samples > len(y) and len(y) >= num_samples:
        start_index = len(y) - num_samples

    y = y[start_index:start_index + num_samples]

    # Too short: pad with zeros at the end
    if len(y) < num_samples:
        y = np.pad(y, (0, num_samples - len(y)), 'constant')

    # Half precision keeps the list of loaded signals small in memory
    return y.astype(np.float16)


# Choose between the whole dataset and the partial (top-N classes) dataset
if executionMode == allDataset:
    mlb = mlbAllDataset                  # mlb to use
    labelsToLoad = audioLabelsBinary     # audioLabels to use
    audioFilesToExtract = audioFiles     # audio files to load
elif executionMode == partialDataset:
    mlb = mlbPartialDataset
    labelsToLoad = audioLabelsSubsetBinary
    audioFilesToExtract = audioFilesSubset
else:
    # Fail here rather than with a NameError further down
    raise ValueError(f"Unsupported executionMode: {executionMode!r}")

signals = []

timeExtraction = 10  # in seconds
samplingRate = 22050
num_samples = int(samplingRate * timeExtraction)
minAmplitude = 0.5

# Extraction of files via librosa load (TQDM to show progress)
for audioFile in tqdm(audioFilesToExtract, desc="Loading audio files..."):
    y, _ = librosa.load(audioFile, sr=samplingRate)
    signals.append(_extractExcerpt(y, num_samples, minAmplitude))

In [None]:
melSpegrams = []

# Iterates over signals and computes a dB-scaled mel spectrogram for each one.
# The loop variable is deliberately not called `signal`: that name is the
# alias of scipy.signal imported at the top of the notebook.
for audioSignal in tqdm(signals, desc="Processing audio signals..."):

    # Signals are stored as float16; upcast (lossless) before spectral analysis
    S = librosa.feature.melspectrogram(y=audioSignal.astype(np.float32), sr=samplingRate)

    # Power -> dB, relative to the loudest bin of this spectrogram
    S_dB = librosa.power_to_db(S, ref=np.max)
    melSpegrams.append(S_dB)

In [None]:
# Plot of spectrograms: listen to and display about five evenly spaced examples

# Step between examples; at least 1, because range() rejects a step of 0
# (which is what len(signals) / 5 truncates to with fewer than 5 signals)
plotStep = max(1, len(signals) // 5)

for i in range(0, len(signals), plotStep):
    print("\n\n-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-")

    # Back from the multi-hot row to the label names
    original_labels = mlb.inverse_transform(np.array([labelsToLoad[i]]))
    print("Labels:", original_labels)

    ipd.display(ipd.Audio(signals[i], rate=samplingRate))

    plt.figure(figsize=(10,6))
    librosa.display.specshow(melSpegrams[i], sr=samplingRate, x_axis='time', y_axis='mel', fmax=samplingRate/2)
    plt.clim(-80,None)
    plt.colorbar()

    # File name without its folders, independent of the OS path separator
    filename = Path(audioFilesToExtract[i]).name

    plt.title(f'{filename} (data #{i})')
    plt.show()

The user now has to choose which data they want to load.

In [None]:
# Convert the list of spectrograms and the label matrix to NumPy arrays
melSpegrams = np.asarray(melSpegrams)
labelsToLoad = np.asarray(labelsToLoad)

In [None]:
from sklearn.utils import shuffle
from skmultilearn.model_selection import iterative_train_test_split

# Shuffle spectrograms and labels together (fixed seed)
melSpegrams, labelsToLoad = shuffle(melSpegrams, labelsToLoad, random_state=1234)

# X = mel spectrograms
# y = labels (multi-hot)
X = np.array(melSpegrams)
y = np.array(labelsToLoad)

# Split: Train (70%), Temp (30%)
X_train, y_train, X_temp, y_temp = iterative_train_test_split(X, y, test_size=0.30)

# Split Temp in half: Validation (15%), Test (15%)
X_val, y_val, X_test, y_test = iterative_train_test_split(X_temp, y_temp, test_size=0.5)

# Summary (sample and label counts of each split must match)
print(f"Train samples:      {len(X_train)}")
print(f"Validation samples: {len(X_val)}")
print(f"Test samples:       {len(X_test)}")
print(f"Train labels:      {len(y_train)}")
print(f"Validation labels: {len(y_val)}")
print(f"Test labels:       {len(y_test)}")


# Number of positive examples per class in each split
print("Train:", np.sum(y_train, axis=0))
print("Val:  ", np.sum(y_val, axis=0))
print("Test: ", np.sum(y_test, axis=0))

In [None]:
# Positive examples per class in each split
train_counts, val_counts, test_counts = (
    np.sum(splitLabels, axis=0) for splitLabels in (y_train, y_val, y_test)
)

# Bar colours
ruby, sapphire, emerald = '#9B111E', '#0F52BA', '#50C878'

labelsChart = mlb.classes_
x = np.arange(len(labelsChart))  # one slot per class
width = 0.25                     # bar width

plt.figure(figsize=(14, 5))

# Three side-by-side bars per class: train | validation | test
barSpecs = (
    (x - width, train_counts, 'Train', ruby),
    (x, val_counts, 'Validation', sapphire),
    (x + width, test_counts, 'Test', emerald),
)
for positions, counts, splitName, colour in barSpecs:
    plt.bar(positions, counts, width, label=splitName, color=colour)

# X axis and titles
plt.xticks(x, labelsChart, rotation=45, ha='right')
plt.xlabel('Class')
plt.ylabel('Occurrences')
plt.title(f'Class distribution in Dataset (Total elements: {np.sum(labelsToLoad)})')
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
from keras.metrics import BinaryAccuracy, AUC

# === Input shape and number of output classes ===
inputShape = (128, melSpegrams[0].shape[1], 1)  # (n_mels, time_frames, channels)
numClasses = labelsToLoad.shape[1]              # number of multilabel classes

# === Layer stack, assembled in order ===
cnnLayers = [layers.Input(shape=inputShape)]

# Blocks 1-3: two 3x3 convolutions, 3x3 max pooling, dropout
for filters in (32, 64, 128):
    cnnLayers += [
        layers.Conv2D(filters, (3, 3), activation='relu', padding='same'),
        layers.Conv2D(filters, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((3, 3), padding='same'),
        layers.Dropout(0.25),
    ]

cnnLayers += [
    # Block 4: two convolutions followed by global max pooling
    layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
    layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
    layers.GlobalMaxPooling2D(),
    layers.Dropout(0.25),

    # Fully connected
    layers.Dense(1024, activation='relu'),
    layers.Dropout(0.5),

    # Output layer (sigmoid for multilabel)
    layers.Dense(numClasses, activation='sigmoid'),
]

modelCNN = models.Sequential(cnnLayers)

# === Compile the model ===
modelCNN.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[BinaryAccuracy(name='binary_accuracy'), AUC(multi_label=True, name='auc')]
)

# === Summary ===
modelCNN.summary()

# Optional: show classes
print(f"Number of classes: {numClasses}")
print(f"Class names: {mlb.classes_}")

In [None]:
# Paths to save logs and models
csvLogPath = 'training_log.csv'
checkpointPath = 'best_model.h5'

# CSVLogger: logs every epoch to CSV (append=True: an existing file is
# appended to rather than overwritten)
csvLogger = CSVLogger(csvLogPath, append=True)

# EarlyStopping: stop if val_binary_accuracy doesn't improve for 100 epochs,
# then restore the weights of the best epoch
earlyStop = EarlyStopping(
    monitor='val_binary_accuracy',
    patience=100,
    restore_best_weights=True,
    verbose=1
)

# ModelCheckpoint: save the best model based on val_binary_accuracy
checkpoint = ModelCheckpoint(
    filepath=checkpointPath,
    monitor='val_binary_accuracy',
    save_best_only=True,
    verbose=1
)

# Bundle them (this list is passed to fit())
callbacks = [csvLogger, earlyStop, checkpoint]

In [None]:
from keras.optimizers import Adam
from keras.metrics import BinaryAccuracy, AUC

# Training hyper-parameters
learningRate = 0.001
batchSize = 32
epochs = 300

# Compile the model with an explicit learning rate. The optimizer is taken
# from the same `keras` package as the layers and metrics, instead of mixing
# in tf.keras.
opt = Adam(learning_rate=learningRate)
modelCNN.compile(
    optimizer=opt,
    loss='binary_crossentropy',
    metrics=[BinaryAccuracy(name='binary_accuracy'), AUC(multi_label=True, name='auc')]
)

# Train; progress is silent (verbose=0), the callbacks log and checkpoint
history = modelCNN.fit(
    X_train,
    y_train,
    validation_data=(X_val, y_val),
    batch_size=batchSize,
    epochs=epochs,
    verbose=0,
    callbacks=callbacks,
)


In [None]:
plt.figure(figsize=(20, 8))

# (subplot position, train key, validation key, title, y-axis label)
panels = (
    (1, 'binary_accuracy', 'val_binary_accuracy', 'Model binary Accuracy', 'Binary Accuracy'),
    (2, 'loss', 'val_loss', 'Model Loss', 'Loss'),
)

# One panel per metric, train and validation curves overlaid
for position, trainKey, valKey, panelTitle, yLabel in panels:
    plt.subplot(1, 2, position)
    plt.plot(history.history[trainKey])
    plt.plot(history.history[valKey])
    plt.title(panelTitle)
    plt.xlabel('Epoch')
    plt.ylabel(yLabel)
    plt.legend(['Train', 'Validation'])

# Best validation accuracy reached during training
bestValAccuracy = np.max(history.history['val_binary_accuracy'])
print('Best validation accuracy: ', bestValAccuracy)

In [None]:
# Run evaluate on the test set and show the raw result
results = modelCNN.evaluate(X_test, y_test, verbose=0)
print(results)  # shows how many values were returned

# More than two values are returned: index 0 is the loss, index 1 the first
# compiled metric (binary accuracy); the remaining ones are not used here
testLoss, testAccuracy = results[:2]

print(f"Test Loss: {testLoss:.4f}")
print(f"Test Accuracy: {testAccuracy:.4f}")

In [None]:
# Reload the weights saved by the ModelCheckpoint callback. The path comes
# from checkpointPath so that it cannot drift from the one used for saving.
modelCNN.load_weights(checkpointPath, by_name=False)

# Loss and binary accuracy of the reloaded model on the test set
resultsTest = modelCNN.evaluate(X_test, y_test)
print(f'Test Loss: {resultsTest[0]} \nTest Accuracy: {resultsTest[1]}')

# Same on the validation set
resultsVal = modelCNN.evaluate(X_val, y_val)
print(f'Val Loss: {resultsVal[0]} \nVal Accuracy: {resultsVal[1]}')

In [None]:
# Predict the labels of the test set
threshold = 0.5
y_pred_probs = modelCNN.predict(X_test)                     # returns probabilities for each label (from sigmoid outputs)
y_pred_binary = (y_pred_probs > threshold).astype(int)      # a class is considered active when its probability is strictly above the threshold

In [None]:
# Confusion matrix considering "MIX" as a separate label
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay


def _toSingleLabel(labelRows):
    """Turn multi-hot rows into class names, or "MIX" unless exactly one label is active."""
    names = []
    for row in labelRows:
        active = np.flatnonzero(row == 1)  # positions holding a 1
        # NOTE(review): rows with no active label also end up as "MIX"
        names.append(mlb.classes_[active[0]] if len(active) == 1 else "MIX")
    return names


# Ground truth and predictions reduced to one name per sample
y_test_simplified = _toSingleLabel(y_test)
y_pred_simplified = _toSingleLabel(y_pred_binary)

# Plot the confusion matrix over all classes plus "MIX"
allLabels = [*mlb.classes_, "MIX"]

cm = confusion_matrix(y_test_simplified, y_pred_simplified, labels=allLabels)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=allLabels)
disp.plot(cmap="Blues", xticks_rotation=45)

In [None]:
# Zoom on a window of the confusion matrix (rows/columns 40-59)
limitStart, limitStop = 40, 60

# With few classes (e.g. the partial dataset) that window lies outside the
# matrix: the slice would be empty and an empty matrix cannot be plotted.
# Fall back to the whole matrix in that case.
if limitStart >= cm.shape[0]:
    limitStart, limitStop = 0, cm.shape[0]

cmLimited = cm[limitStart:limitStop, limitStart:limitStop]
disp = ConfusionMatrixDisplay(confusion_matrix=cmLimited, display_labels=allLabels[limitStart:limitStop])
disp.plot(cmap="Blues", xticks_rotation=45)

In [None]:
import seaborn as sns
from sklearn.metrics import multilabel_confusion_matrix

# One 2x2 confusion matrix per class, computed on the multi-hot labels
conf_matrices = multilabel_confusion_matrix(y_test, y_pred_binary)

class_names = mlb.classes_

# Heatmap for each class. The loop variable is not called `cm` so that the
# global confusion matrix built in the earlier cell is not overwritten.
for className, classCm in zip(class_names, conf_matrices):
    plt.figure(figsize=(4, 3))
    sns.heatmap(classCm, annot=True, fmt="d", cmap="Blues", cbar=False,
                xticklabels=["Pred 0", "Pred 1"],
                yticklabels=["True 0", "True 1"])
    plt.title(f"Confusion Matrix for class: {className}")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.tight_layout()
    plt.show()

# Text summary, printed after all the plots
for className, classCm in zip(class_names, conf_matrices):
    # Row-major order of [[TN, FP], [FN, TP]]
    tn, fp, fn, tp = classCm.ravel()
    print(f"{className} — TP: {tp}, FP: {fp}, FN: {fn}, TN: {tn}")

In [None]:
# Classification report (per-class precision / recall / F1 on the test set)
from sklearn.metrics import classification_report

# zero_division=0: classes with no predicted samples score 0 without a warning
reportOptions = dict(target_names=mlb.classes_, zero_division=0)
report = classification_report(y_test, y_pred_binary, **reportOptions)

print(report)