
<br>
<h1 style="text-align:center; font-size:46px; font-weight:800; margin-bottom:0;">
Cybersecurity - IoT Intrusion Dataset
</h1>
<p style="text-align:center; font-size:22px; color:gray; margin-top:5px;">
Conv1D Neural Network for Time Series Analysis
</p>
<br>
<p style="text-align:center;">
  <img src="img/ai.jpg" style="width:85%; max-width:900px; border-radius:10px; box-shadow:0px 0px 12px rgba(0,0,0,0.25);">
</p>
<br>
<p style="text-align:center; font-size:16px; color:gray;">
Environment &amp; Tools: Python • PyTorch • NumPy • Pandas • Scikit-Learn • Matplotlib
</p>
<br>


The dataset used in this analysis is publicly available from an open-source repository on Kaggle.  
You can access it here: [insert URL].

We will be using a **1D Convolutional Neural Network (Conv1D CNN)** to perform **time-series analysis** on the dataset.
Conv1D is particularly suited for sequential data, as it can capture temporal patterns across the input sequence of features.


<hr>
<h2 style="text-align:center; font-weight:700;">📚 Libraries Used in This Project</h2>
<p style="text-align:center; color:gray; font-size:15px;">Core tools used for data processing, modeling, and visualization.</p>
<br>
<table style="margin-left:auto; margin-right:auto; text-align:center;">
  <tr>
    <td><img src="img/numpy.jpg" width="60"><br><strong>NumPy</strong><br><span style="color:gray; font-size:13px;">Numerical<br>Computing</span></td>
    <td><img src="img/pandas.png" width="60"><br><strong>Pandas</strong><br><span style="color:gray; font-size:13px;">Data<br>Manipulation</span></td>
    <td><img src="img/pytorch.jpg" width="100"><br><strong>PyTorch</strong><br><span style="color:gray; font-size:13px;">Deep<br>Learning</span></td>
  </tr>
  <tr>
    <td><img src="img/scikitlearn.png" width="60"><br><strong>Scikit-Learn</strong><br><span style="color:gray; font-size:13px;">ML Tools<br>& Metrics</span></td>
    <td><img src="img/matplotlib.jpg" width="90"><br><strong>Matplotlib</strong><br><span style="color:gray; font-size:13px;">Visualization</span></td>
  </tr>
</table>
<br>
<hr>


Importing Libraries and Modules

In [None]:
"""
Created on Thu Oct 23 18:45:00 2025

@author: ybenjaminpcondori
"""

# System & OS utilities
import os

# Data manipulation
import numpy as np
import pandas as pd

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Scikit-Learn
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc, precision_recall_fscore_support
from sklearn.metrics import (classification_report, f1_score, recall_score,
                             confusion_matrix, roc_auc_score, accuracy_score)
from sklearn.metrics import precision_score


# Visualization
import matplotlib.pyplot as plt
import matplotlib as mpl

# Importing Seaborn for enhanced visualizations
import seaborn as sns


## Formatting for Presentation purposes on Evaluation

Declaring Classes and Functions

In [None]:
# Convolutional Neural Network Definition
class Convolutional_Neural_Network(nn.Module):
    """1D CNN classifier that treats each sample's feature vector as a sequence.

    The network stacks three Conv1d + BatchNorm1d + ReLU blocks (the first two
    followed by channel dropout and max pooling), collapses the sequence axis
    with global average pooling, and classifies with a small fully connected
    head.

    Args:
        num_classes: Number of output logits.
        num_features: Length of each input feature vector. Must be at least 4
            because the two ``MaxPool1d(2)`` stages each halve the length.

    Raises:
        ValueError: If ``num_features`` is smaller than 4.
    """

    # Two MaxPool1d(2) stages divide the sequence length by 4; anything shorter
    # leaves the third convolutional block with an empty input.
    _MIN_FEATURES = 4

    def __init__(self, num_classes, num_features):
        super().__init__()

        if num_features < self._MIN_FEATURES:
            raise ValueError(
                f"num_features must be >= {self._MIN_FEATURES}, got {num_features}"
            )

        # Convolutional feature extractor
        self.conv1 = nn.Conv1d(1, 64, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(64)

        # Second convolutional block
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(128)

        # Third convolutional block
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)

        # Regularisation inside convolutional blocks
        self.conv_dropout = nn.Dropout1d(0.2)

        # Pooling
        self.pool = nn.MaxPool1d(2)
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        # Activation
        self.relu = nn.ReLU()

        # The global average pool reduces the sequence axis to length 1, so the
        # flattened feature width is always the channel count of the last
        # convolution. Reading it directly avoids a dummy forward pass, which
        # (run in training mode) would update the BatchNorm running statistics
        # and fail for short inputs with a batch of one.
        self.flatten_dim = self.conv3.out_channels

        # Reduced fully connected head (less memorisation)
        self.fc1 = nn.Linear(self.flatten_dim, 64)
        self.dropout = nn.Dropout(0.3)
        self.out = nn.Linear(64, num_classes)

    # Feature extraction method
    def _extract_features(self, x):
        """Map a (batch, 1, features) tensor to a (batch, channels) embedding."""

        # First convolutional block: conv -> batch norm -> ReLU -> dropout -> pool
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.conv_dropout(x)
        x = self.pool(x)

        # Second convolutional block
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.conv_dropout(x)
        x = self.pool(x)

        # Third convolutional block, followed by global average pooling
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.global_pool(x)

        x = x.squeeze(-1)  # (batch, channels)
        return x

    # Forward pass
    def forward(self, x):
        """Return class logits for a (batch, features) input tensor."""
        x = x.unsqueeze(1)  # (batch, 1, features)
        x = self._extract_features(x)
        x = self.dropout(self.relu(self.fc1(x)))
        return self.out(x)

# Binary counting helper class
class DataPreprocessing:
    """Namespace for small preprocessing helpers."""

    @staticmethod
    def count_binary(series):
        """Return how often the values 0 and 1 occur in *series*.

        The result is a Series indexed by the strings '0' and '1'; a value
        that never occurs is reported with a count of 0.
        """
        tally = series.value_counts()
        return pd.Series({str(flag): tally.get(flag, 0) for flag in (0, 1)})

## Data Preprocessing

Loading the Dataset

In [None]:
# Load the raw IoT intrusion CSV into a DataFrame
df = pd.read_csv("IoT_Intrusion.csv")

# Normalise the header: strip surrounding whitespace, then lowercase
trimmed_names = df.columns.str.strip()
df.columns = trimmed_names.str.lower()

# Preview the first few rows
df.head()

Declaration of Columns: Network Protocols

In [None]:
# Protocol columns analysed as binary (0/1) columns in the following cells
protocol_columns = ['http','https','dns','telnet','smtp','ssh','irc','tcp','udp','dhcp','arp','icmp','ipv','llc']

# Fail early and name every absent column at once. The previous loop computed
# value_counts() per column and threw the result away, so its only effect was
# an implicit KeyError on the first missing column.
missing_protocol_columns = [col for col in protocol_columns if col not in df.columns]
if missing_protocol_columns:
    raise KeyError(f"Protocol columns missing from dataset: {missing_protocol_columns}")

Feature Importance: Number of instances of 0/1 in each feature

In [None]:
# Tally the 0/1 occurrences of every protocol column
counts_df = df[protocol_columns].apply(DataPreprocessing.count_binary)
print(counts_df)

# Summarise the target label column
label_series = df['label']
print("Binary label statistics:")
print("Unique label count:", label_series.nunique())
print("Unique labels:", label_series.unique())
print("Label counts:\n", label_series.value_counts())

Calculating unique values, which is important for identifying features (columns) that don't contribute to the prediction/classification

In [None]:
# -----------------------------
# Distinct values per protocol column
# -----------------------------

print("\nUnique Value Counts", "=" * 50, sep="\n")

protocol_columns_nunique = df.loc[:, protocol_columns].nunique()
print(protocol_columns_nunique)

Declaration of Columns: Signal Processing Columns

In [None]:
# Columns examined by the null-value and unique-value checks in the next cells
# (flow timing, header size, rates and TCP flag counters).
other_columns = [
    "flow_duration",
    "header_length",
    "protocol_type",
    "duration",
    "rate",
    "srate",
    "drate",
    "fin_flag_number",
    "syn_flag_number",
]

In [None]:
# Report the missing-value count and fraction for each of the "other" columns
print("Null Value Analysis for Other Columns")
print("=" * 50)

for col in other_columns:
    # Build the missing-value mask once and derive both statistics from it
    is_missing = df[col].isnull()
    null_count = is_missing.sum()
    null_ratio = is_missing.mean()

    print(f"{col:20s} | nulls: {null_count:8d} | ratio: {null_ratio:.4f}")

In [None]:
# -----------------------------
# Distinct values per "other" column
# -----------------------------

print("\nUnique Value Counts", "=" * 50, sep="\n")

other_columns_nunique = df.loc[:, other_columns].nunique()
print(other_columns_nunique)

Declaration of Columns: Numeric Columns

In [None]:
# Columns with different names
# These are the names as they exist in df at this point (lowercased header).
# A later cell renames them: "tot size" -> "total_size",
# "magnitue" -> "magnitude", "iat" -> "inter_arrival_time".
different_columns = [
    "tot size",
    "magnitue",
    "iat",
]

In [None]:

# Renaming columns for consistency
column_renames = {
    "tot size": "total_size",
    "magnitue": "magnitude",
    "iat": "inter_arrival_time",
}
df = df.rename(columns=column_renames)

# Keep the column list in sync with the rename: the following cells index df
# with different_columns, which would raise KeyError on the old names.
# Using .get(name, name) makes re-running this cell harmless.
different_columns = [column_renames.get(name, name) for name in different_columns]


In [None]:
# Report the missing-value count and fraction for each of the "different" columns
print("Null Value Analysis for Different Columns")
print("=" * 50)

for col in different_columns:
    # Build the missing-value mask once and derive both statistics from it
    is_missing = df[col].isnull()
    null_count = is_missing.sum()
    null_ratio = is_missing.mean()

    print(f"{col:20s} | nulls: {null_count:8d} | ratio: {null_ratio:.4f}")

Calculating unique values, which is important for identifying features (columns) that don't contribute to the prediction/classification

In [None]:
# -----------------------------
# Distinct values per "different" column
# -----------------------------
print("\nUnique Value Counts", "=" * 50, sep="\n")

different_columns_nunique = df.loc[:, different_columns].nunique()
print(different_columns_nunique)

Data Cleaning

In [None]:
# Detecting rows with identical features, but different target feature (Label)

# Identify feature columns (exclude target). A plain list is used for
# groupby/merge so pandas treats the entries as column labels; a pandas Index
# passed to groupby is interpreted as a single array-like key instead.
feature_cols = df.columns.difference(['label'])
feature_col_list = feature_cols.tolist()

# Count the distinct labels attached to each unique feature pattern.
# Note: groupby drops rows whose key columns contain NaN (default dropna=True).
label_variation = (
    df
    .groupby(feature_col_list)['label']
    .nunique()
    .reset_index(name='n_labels')
)

# Keep only conflicting feature patterns
conflicting_patterns = label_variation[label_variation['n_labels'] > 1]

# Retrieve all rows involved in label conflicts
df_conflicts = df.merge(
    conflicting_patterns[feature_col_list],
    on=feature_col_list,
    how='inner'
)

# Report only after df_conflicts exists (printing it first raised NameError)
print("Detecting feature-identical rows with multiple target labels")
print("=" * 50)
print(df_conflicts.sort_values(feature_col_list))

In [None]:

# Find feature columns (label excluded) that hold exactly one distinct value
# across the full dataset, then drop them from df in place.
feature_cols = df.columns.difference(['label'])

global_nunique = df[feature_cols].nunique()

is_constant = global_nunique.eq(1)
constant_cols = global_nunique.loc[is_constant].index.tolist()

print("Dropping constant columns:", constant_cols)

df.drop(columns=constant_cols, inplace=True)

Identify Features and Target Columns

In [None]:
# Fit the encoder on the label column, then store the integer class ids
encoder = LabelEncoder().fit(df['label'])
df['label_encoded'] = encoder.transform(df['label'])

In [None]:

# Class to Attack Label Mapping

# Class ids are positions in encoder.classes_, so enumerate() yields id -> name
id_to_attack = dict(enumerate(encoder.classes_))
attack_to_id = {label: idx for idx, label in id_to_attack.items()}

# Displaying the class ID to attack type mapping
# (the arrow and check mark below were mis-encoded as mojibake before)
print("\nClass ID → Attack Type mapping:\n")
for class_id, attack_name in id_to_attack.items():
    print(f"Class {class_id}: {attack_name}")

# Saving the mapping to a CSV file
mapping_df = pd.DataFrame({
    "Class ID": list(id_to_attack.keys()),
    "Attack Type": list(id_to_attack.values())
})
mapping_df.to_csv("class_label_mapping.csv", index=False)
print("\n[✓] Saved class_label_mapping.csv")


In [None]:
# Select the target and the feature columns (everything except both label columns)
target = 'label_encoded'
label_columns = ('label', 'label_encoded')
features = [c for c in df.columns if c not in label_columns]

# Replace missing values in numeric features with the column mean.
# NOTE(review): the means are computed over the full dataset, i.e. before the
# train/test split performed further down - confirm this is intended.
numeric_features = df[features].select_dtypes(include=np.number).columns
column_means = df[numeric_features].mean()
df[numeric_features] = df[numeric_features].fillna(column_means)


In [None]:
# Build the feature matrix X and the target vector y as NumPy arrays
X, y = df[features].values, df[target].values

print(X.shape, y.shape)

Train / Test Split

In [None]:

# Stratified 80/20 hold-out split with a fixed seed
split_options = dict(test_size=0.2, stratify=y, random_state=42)
X_train_np, X_test_np, y_train_np, y_test_np = train_test_split(X, y, **split_options)

In [None]:
# Fit the scaler on the training split only, then apply it to both splits
scaler = StandardScaler().fit(X_train_np)
X_train_np = scaler.transform(X_train_np)
X_test_np = scaler.transform(X_test_np)

In [None]:
# Features become float32 tensors; labels become long (int64) tensors
X_train, X_test = (
    torch.tensor(arr, dtype=torch.float32) for arr in (X_train_np, X_test_np)
)
y_train, y_test = (
    torch.tensor(arr, dtype=torch.long) for arr in (y_train_np, y_test_np)
)

# Show the resulting tensor shapes
print(X_train.shape, X_test.shape)



In [None]:
# Wrap the tensors in DataLoaders; only the training batches are shuffled
batch_size = 512
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

Model Definition

In [None]:
# Derive the model dimensions from the data
num_classes = df[target].unique().size
num_features = X_train.shape[1]


# Build the Conv1D network and move it to the selected device
model = Convolutional_Neural_Network(num_classes, num_features)
model = model.to(device)
print(model)

Training

In [None]:
# Loss function: cross-entropy over the class logits
loss_function = nn.CrossEntropyLoss()

# Optimizer hyperparameters
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4  # optional but recommended
)

# Learning rate scheduler: halve the LR after 3 epochs without improvement in
# the monitored loss, never going below 1e-5.
# `verbose` is intentionally not passed: it is deprecated since PyTorch 2.2
# (newer releases may reject it), and the training loop already prints the
# current learning rate every epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    min_lr=1e-5,
)

# Training hyperparameters
num_epochs = 50

# Early Stopping
best_val_loss = float("inf")
patience = 5          # stop after 5 bad epochs
patience_ctr = 0
min_delta = 1e-4      # minimum loss decrease that counts as an improvement
best_model_path = "best_model.pt"


In [None]:
# Training loop with validation, learning rate scheduling and early stopping.
# Early stopping uses the variables declared in the setup cell
# (best_val_loss, patience, patience_ctr, min_delta, best_model_path), which
# were previously declared but never applied.
# NOTE(review): test_loader doubles as the validation set, so the checkpoint is
# selected on the same data later used for evaluation - reported test metrics
# are therefore optimistic. Consider a separate validation split.
best_epoch = None  # epoch (1-based) of the checkpoint saved during this run

for epoch in range(num_epochs):
    # -------- Train --------
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    # For looping through training batches
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        # Training using backpropagation
        optimizer.zero_grad()
        logits = model(batch_x)
        loss = loss_function(logits, batch_y)
        loss.backward()
        optimizer.step()

        # Accumulate loss (sum over samples; the criterion returns a batch mean)
        total_loss += loss.item() * batch_x.size(0)

        # Accumulate accuracy
        preds = torch.argmax(logits, dim=1)
        correct += (preds == batch_y).sum().item()
        total += batch_y.size(0)

    # Calculating average training loss/accuracy
    train_loss = total_loss / len(train_loader.dataset)
    train_acc = correct / total if total > 0 else 0.0

    # -------- Validation (using test_loader as validation set) --------
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for bx, by in test_loader:
            bx, by = bx.to(device), by.to(device)
            logits = model(bx)
            val_loss += loss_function(logits, by).item() * bx.size(0)

            preds = torch.argmax(logits, dim=1)
            val_correct += (preds == by).sum().item()
            val_total += by.size(0)

    # Calculating average validation loss/accuracy
    val_loss /= len(test_loader.dataset)
    val_acc = val_correct / val_total if val_total > 0 else 0.0

    # Scheduler step
    scheduler.step(val_loss)

    # Printing Training progress
    print(
        f"Epoch {epoch+1}/{num_epochs} | "
        f"Train Acc: {train_acc:.4f} | "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Acc: {val_acc:.4f} | "
        f"Val Loss: {val_loss:.4f} | "
        f"LR: {optimizer.param_groups[0]['lr']:.2e}"
    )

    # -------- Early stopping --------
    if val_loss < best_val_loss - min_delta:
        # Meaningful improvement: remember it and checkpoint the weights
        best_val_loss = val_loss
        patience_ctr = 0
        best_epoch = epoch + 1
        torch.save(model.state_dict(), best_model_path)
    else:
        patience_ctr += 1
        if patience_ctr >= patience:
            print(
                f"Early stopping at epoch {epoch+1} "
                f"(best epoch: {best_epoch}, best val loss: {best_val_loss:.4f})"
            )
            break

# Restore the best weights, but only if this run actually wrote a checkpoint
# (avoids loading a stale file left over from an earlier run).
if best_epoch is not None:
    model.load_state_dict(torch.load(best_model_path, map_location=device))

Evaluation

In [None]:
# Global Matplotlib style applied to every figure created afterwards.
# The settings are collected in one dict and applied with a single update.
publication_style = {
    # Resolution (on-screen vs. saved figures)
    "figure.dpi": 150,
    "savefig.dpi": 300,

    # Serif font with fallbacks (IEEE/Springer safe)
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],

    # Text sizes and weights
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.titleweight": "bold",
    "axes.labelsize": 11,
    "axes.labelweight": "bold",

    # Tick labels
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,

    # Legend
    "legend.fontsize": 9,
    "legend.frameon": False,

    # Clean axes: no top/right spines, no grid by default
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.grid": False,

    # Line widths
    "lines.linewidth": 2,
}
mpl.rcParams.update(publication_style)

In [None]:

# Evaluation metrics with attack names 



def evaluate_with_names(all_true, all_preds, all_probs, id_to_attack):
    """Print and return classification metrics, labelled with attack names.

    Every metric is computed over the full label set ``0 .. len(id_to_attack)-1``.
    Passing ``labels`` explicitly keeps per-class rows aligned with
    ``id_to_attack`` even when a class is absent from ``all_true`` and
    ``all_preds``; without it, ``classification_report`` raises on the
    ``target_names`` length mismatch and per-class scores shift position.

    Args:
        all_true: Ground-truth class ids.
        all_preds: Predicted class ids.
        all_probs: Predicted class probabilities. Not used by this function;
            kept so existing callers keep working.
        id_to_attack: Mapping from class id (0-based, contiguous) to class name.

    Returns:
        dict with keys 'macro_f1', 'weighted_f1', 'macro_recall',
        'weighted_recall', 'accuracy' and 'confusion_matrix'.
    """
    num_classes = len(id_to_attack)
    labels = list(range(num_classes))
    class_names = [id_to_attack[i] for i in labels]

    print("=" * 80)
    print("MODEL EVALUATION METRICS")
    print("=" * 80)

    print("\nClassification Report:")
    print(classification_report(
        all_true,
        all_preds,
        labels=labels,
        target_names=class_names,
        zero_division=0
    ))

    macro_f1 = f1_score(all_true, all_preds, labels=labels,
                        average='macro', zero_division=0)
    weighted_f1 = f1_score(all_true, all_preds, labels=labels,
                           average='weighted', zero_division=0)

    print("\nF1 Scores:")
    print(f"  Macro F1 Score:     {macro_f1:.4f}")
    print(f"  Weighted F1 Score:  {weighted_f1:.4f}")

    macro_recall = recall_score(all_true, all_preds, labels=labels,
                                average='macro', zero_division=0)
    weighted_recall = recall_score(all_true, all_preds, labels=labels,
                                   average='weighted', zero_division=0)
    per_class_recall = recall_score(all_true, all_preds, labels=labels,
                                    average=None, zero_division=0)

    print("\nRecall Scores:")
    print(f"  Macro Recall:     {macro_recall:.4f}")
    print(f"  Weighted Recall:  {weighted_recall:.4f}")
    print("  Per-class Recall:")
    # per_class_recall follows the order of `labels`, i.e. of class_names
    for name, r in zip(class_names, per_class_recall):
        print(f"    {name}: {r:.4f}")

    cm = confusion_matrix(all_true, all_preds, labels=labels)
    accuracy = accuracy_score(all_true, all_preds)

    print(f"\nOverall Accuracy: {accuracy:.4f}")
    print("=" * 80)

    return {
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'macro_recall': macro_recall,
        'weighted_recall': weighted_recall,
        'accuracy': accuracy,
        'confusion_matrix': cm
    }


Data Analysis / Visualization

In [None]:

# ROC-AUC curves and metrics heatmap with attack names




def plot_roc_and_heatmap(all_true, all_probs, id_to_attack,
                          max_classes=5,
                          save_path='roc_and_metrics_heatmap.png'):
    """Plot one-vs-rest ROC curves and a per-class metrics heatmap side by side.

    Only the first ``max_classes`` classes (by class id) are drawn in either
    panel. The figure is saved to ``save_path`` and then shown.

    Args:
        all_true: Ground-truth class ids.
        all_probs: Array of shape (n_samples, n_classes) with class scores;
            column ``i`` is used as the score for class ``i``.
        id_to_attack: Mapping from class id (0-based, contiguous) to class name.
        max_classes: Maximum number of classes to draw.
        save_path: Output image path.
    """
    num_classes = len(id_to_attack)
    labels = list(range(num_classes))
    class_names = [id_to_attack[i] for i in labels]
    n_shown = min(num_classes, max_classes)

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # ROC curves (OvR)
    all_true_bin = label_binarize(all_true, classes=labels)
    if num_classes == 2:
        # For two classes label_binarize returns a single column (the indicator
        # of class 1); expand it so column i is the indicator of class i.
        all_true_bin = np.hstack([1 - all_true_bin, all_true_bin])

    for i in range(n_shown):
        fpr, tpr, _ = roc_curve(all_true_bin[:, i], all_probs[:, i])
        roc_auc = auc(fpr, tpr)
        axes[0].plot(fpr, tpr, lw=2,
                     label=f'{class_names[i]} (AUC={roc_auc:.3f})')

    # Diagonal = chance-level reference
    axes[0].plot([0, 1], [0, 1], 'k--', lw=2)
    axes[0].set_xlabel('False Positive Rate')
    axes[0].set_ylabel('True Positive Rate')
    axes[0].set_title('ROC Curves (One-vs-Rest)')
    axes[0].legend(fontsize=8)
    axes[0].grid(True, alpha=0.3)

    # Metrics heatmap. `labels` pins row i to class i even when a class does
    # not occur in the evaluated data.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_true, np.argmax(all_probs, axis=1),
        labels=labels, average=None, zero_division=0
    )
    metrics_data = np.vstack([precision, recall, f1]).T

    im = axes[1].imshow(metrics_data[:n_shown], cmap='RdYlGn', vmin=0, vmax=1)
    axes[1].set_xticks([0, 1, 2])
    axes[1].set_xticklabels(['Precision', 'Recall', 'F1'])
    axes[1].set_yticks(range(n_shown))
    axes[1].set_yticklabels(class_names[:n_shown])
    axes[1].set_title('Per-Class Metrics Heatmap')

    plt.colorbar(im, ax=axes[1])
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
# Gathering all predictions and probabilities
# The plots below need all_true / all_preds / all_probs and the scalar
# metrics; none of them were defined before, so they are computed here from
# the test loader and evaluate_with_names().
model.eval()
true_batches = []
prob_batches = []
with torch.no_grad():
    for bx, by in test_loader:
        logits = model(bx.to(device))
        # Softmax turns logits into per-class probabilities for the ROC curves
        prob_batches.append(torch.softmax(logits, dim=1).cpu().numpy())
        true_batches.append(by.cpu().numpy())

all_true = np.concatenate(true_batches)
all_probs = np.concatenate(prob_batches)
all_preds = all_probs.argmax(axis=1)

eval_results = evaluate_with_names(all_true, all_preds, all_probs, id_to_attack)
macro_f1 = eval_results['macro_f1']
weighted_f1 = eval_results['weighted_f1']
macro_recall = eval_results['macro_recall']
accuracy = eval_results['accuracy']

# Additional metrics visualization


fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# 1. F1 Scores Comparison
ax1 = axes[0]
metrics = ['Macro F1', 'Weighted F1']
scores = [macro_f1, weighted_f1]
colors_metrics = ['#FF6B6B', '#4ECDC4']
bars = ax1.bar(metrics, scores, color=colors_metrics, alpha=0.8, edgecolor='black', linewidth=2)
ax1.set_ylabel('F1 Score', fontsize=12, fontweight='bold')
ax1.set_title('F1 Scores Comparison', fontsize=14, fontweight='bold')
ax1.set_ylim([0, 1])
ax1.grid(True, alpha=0.3, axis='y')

# Annotate each bar with its value
for bar, score in zip(bars, scores):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height,
             f'{score:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold')

# 2. Macro Metrics Summary
ax2 = axes[1]
macro_precision = precision_score(all_true, all_preds, average='macro', zero_division=0)
summary_metrics = ['Accuracy', 'Precision', 'Recall', 'F1']
summary_values = [accuracy, macro_precision, macro_recall, macro_f1]
colors_summary = ['#95E1D3', '#F38181', '#AA96DA', '#FCBAD3']

bars = ax2.barh(summary_metrics, summary_values, color=colors_summary, alpha=0.8, edgecolor='black', linewidth=2)
ax2.set_xlabel('Score', fontsize=12, fontweight='bold')
ax2.set_title('Macro-Averaged Metrics Summary', fontsize=14, fontweight='bold')
ax2.set_xlim([0, 1])
ax2.grid(True, alpha=0.3, axis='x')

# Annotate each horizontal bar with its value
for bar, value in zip(bars, summary_values):
    width = bar.get_width()
    ax2.text(width, bar.get_y() + bar.get_height()/2.,
             f'{value:.4f}', ha='left', va='center', fontsize=11, fontweight='bold',
             bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

plt.tight_layout()
plt.savefig('metrics_summary.png', dpi=300, bbox_inches='tight')
plt.show()

print(" Metrics summary visualization saved as 'metrics_summary.png'")


In [None]:

# ROC-AUC Curves (One-vs-Rest for multi-class)
# Left panel: ROC curves; right panel: per-class precision/recall/F1 heatmap.


fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# 1. ROC Curve (One-vs-Rest)
ax1 = axes[0]
if num_classes > 2:
    # Multi-class: use label binarization
    all_true_bin = label_binarize(all_true, classes=range(num_classes))

    # Colormaps are defined on [0, 1]; sampling up to 2 clipped every value
    # above 1 to the final colour, so many curves shared one colour.
    colors = plt.cm.Set3(np.linspace(0, 1, num_classes))
    auc_scores = []

    for i in range(min(num_classes, 5)):  # Limit to 5 classes for clarity
        fpr, tpr, _ = roc_curve(all_true_bin[:, i], all_probs[:, i])
        roc_auc = auc(fpr, tpr)
        auc_scores.append(roc_auc)
        ax1.plot(fpr, tpr, color=colors[i], lw=2, label=f'Class {i} (AUC = {roc_auc:.3f})')

    # Micro-average: pool every (sample, class) decision into one curve
    fpr, tpr, _ = roc_curve(all_true_bin.ravel(), all_probs.ravel())
    roc_auc_micro = auc(fpr, tpr)
    ax1.plot(fpr, tpr, color='deeppink', lw=4, linestyle=':', label=f'Micro-average (AUC = {roc_auc_micro:.3f})')

    ax1.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')
    ax1.set_xlim([0.0, 1.0])
    ax1.set_ylim([0.0, 1.05])
    ax1.set_xlabel('False Positive Rate', fontsize=11, fontweight='bold')
    ax1.set_ylabel('True Positive Rate', fontsize=11, fontweight='bold')
    ax1.set_title('ROC Curves (One-vs-Rest) - Top 5 Classes', fontsize=13, fontweight='bold')
    ax1.legend(loc="lower right", fontsize=9)
    ax1.grid(True, alpha=0.3)
else:
    # Binary classification: score of the positive class (column 1)
    fpr, tpr, _ = roc_curve(all_true, all_probs[:, 1])
    roc_auc = auc(fpr, tpr)
    ax1.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
    ax1.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')
    ax1.set_xlim([0.0, 1.0])
    ax1.set_ylim([0.0, 1.05])
    ax1.set_xlabel('False Positive Rate', fontsize=11, fontweight='bold')
    ax1.set_ylabel('True Positive Rate', fontsize=11, fontweight='bold')
    ax1.set_title('ROC Curve', fontsize=13, fontweight='bold')
    ax1.legend(loc="lower right", fontsize=10)
    ax1.grid(True, alpha=0.3)

# 2. Per-Class Metrics Heatmap
# (precision_recall_fscore_support is already imported at the top of the file)
ax2 = axes[1]
precision, recall, f1, support = precision_recall_fscore_support(all_true, all_preds,
                                                                   average=None, zero_division=0)

metrics_data = np.array([precision, recall, f1]).T

# Draw only the rows that get a tick label and annotation, so the image
# matches the "Top 10 Classes" title instead of showing unlabeled rows.
n_heat = min(len(precision), 10)
im = ax2.imshow(metrics_data[:n_heat], cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)

# Set ticks and labels
ax2.set_xticks([0, 1, 2])
ax2.set_xticklabels(['Precision', 'Recall', 'F1'], fontsize=11, fontweight='bold')
ax2.set_yticks(range(n_heat))
ax2.set_yticklabels([f'Class {i}' for i in range(n_heat)], fontsize=10)
ax2.set_title('Per-Class Metrics Heatmap (Top 10 Classes)', fontsize=13, fontweight='bold')

# Add text annotations
for i in range(n_heat):
    for j in range(3):
        ax2.text(j, i, f'{metrics_data[i, j]:.2f}',
                 ha="center", va="center", color="black", fontsize=9, fontweight='bold')

cbar = plt.colorbar(im, ax=ax2)
cbar.set_label('Score', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig('roc_and_metrics_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()

print(" ROC and metrics heatmap visualization saved as 'roc_and_metrics_heatmap.png'")
