# An Improved LPI Radar Waveform Recognition Framework With LDC-Unet and SSR-Loss

### 1 Setup

#### 1.1 Imports

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchinfo import summary
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import h5py
import numpy as np
from sklearn.preprocessing import LabelEncoder
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from tqdm import tqdm
from sklearn.metrics import (
    confusion_matrix,
    ConfusionMatrixDisplay,
    classification_report,
)
import cv2
import seaborn as sns
import os
import pickle
from scipy.signal import get_window, spectrogram
from scipy.fft import fft, fftshift, fftfreq
import cupy as cp

#### 1.2 Device Selection

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

#### 1.3 Data Input

In [None]:
def pre_processing(images, size=(64, 64)):
    """
    Resize every image in a batch to a common spatial size.

    Parameters:
    - images: iterable of image arrays accepted by cv2.resize.
    - size (tuple): target size as (width, height), following OpenCV's
      ``dsize`` convention. Defaults to (64, 64), the resolution the rest of
      this notebook trains on.

    Returns:
    - np.ndarray stacking the resized images along a new leading axis.
    """
    # Bicubic interpolation (cv2.INTER_CUBIC) is used for the resize.
    resized_images = np.array(
        [cv2.resize(img, tuple(size), interpolation=cv2.INTER_CUBIC) for img in images]
    )

    return resized_images

In [None]:
def load_spectrogram_h5s(root_folder, mod_types):
    """
    Loads all .h5 spectrogram files from the specified modulation types
    and returns a dict structured as:
      {
          "lfm_up": {
              0: <np.ndarray>,
              1: <np.ndarray>,
              ...
          },
          ...
      }

    Each .h5 file name (without extension) becomes a class name; each dataset
    key inside the file must be an integer string and becomes the sample index.

    Parameters:
    - root_folder (str): Path to the top-level "spectrograms" directory.
    - mod_types (list): List of modulation families to include, e.g. ['FM', 'PM']

    Returns:
    - Dictionary containing all spectrograms, indexed by modulation name and integer index.
      Files that fail to load are reported and left out entirely (no partial entries).
      Missing family folders are reported and skipped.
    """
    spectrogram_dict = {}

    for mod_type in mod_types:
        mod_path = os.path.join(root_folder, mod_type)
        if not os.path.exists(mod_path):
            print(f"⚠️ Warning: {mod_path} does not exist. Skipping.")
            continue

        print(f"📂 Loading from {mod_type}...")
        # Sort so the load order (and hence the sample order of X/y built from
        # this dict, and the seeded train/test split) is reproducible;
        # os.listdir returns entries in arbitrary order.
        files = sorted(f for f in os.listdir(mod_path) if f.endswith(".h5"))

        for file in tqdm(files, desc=f"   {mod_type}", unit="file"):
            mod_name = file[:-3]  # Strip '.h5'
            file_path = os.path.join(mod_path, file)

            if mod_name in spectrogram_dict:
                print(
                    f"⚠️ Warning: {mod_name} was already loaded from another "
                    f"family; {file_path} will replace it."
                )

            # Fill a local dict first so a failure part-way through a file
            # cannot leave a partially loaded (or empty) class in the result.
            spectros = {}
            try:
                with h5py.File(file_path, "r") as h5f:
                    for key in h5f.keys():
                        idx = int(key)  # Convert string index to int
                        spectros[idx] = np.array(h5f[key])
            except Exception as e:
                # Best-effort loading: report the unreadable file and move on.
                print(f"❌ Failed to load {file_path}: {e}")
                continue

            spectrogram_dict[mod_name] = spectros

    return spectrogram_dict


In [None]:
# Root directory that holds the spectrogram datasets — edit for your machine.
data_path = "C:/Users/scocks/Documents/hehehehehehhe/images/"

img_res = 224
img_count = 1000
snr = 0


folder_name = f"General_Images_res_{img_res}_sz_{img_count}_SNR_{snr}"
# os.path.join works whether or not data_path ends with a separator.
folder_path = os.path.join(data_path, folder_name)

# This is an input folder: fail early with a clear message rather than
# creating an empty directory and silently loading nothing.
if not os.path.isdir(folder_path):
    raise FileNotFoundError(f"Spectrogram folder not found: {folder_path}")

modulation_types = [
    "FM",
    "PM",
    "HYBRID",
]

data = load_spectrogram_h5s(folder_path, modulation_types)
if not data:
    raise RuntimeError(f"No spectrograms were loaded from {folder_path}")

In [None]:
def convert_spectrogram_dict_to_xy(data_dict):
    """
    Flatten a {label: {index: spectrogram}} mapping into parallel (X, y) arrays.

    Parameters:
    - data_dict: mapping as returned by load_spectrogram_h5s(), i.e.
      class name -> {integer index -> np.ndarray}.

    Returns:
    - X: np.ndarray stacking every spectrogram; classes appear in dict order
      and, within a class, samples appear in ascending index order.
    - y: np.ndarray of the matching string labels, e.g. 'lfm_up'.
    """
    samples = [
        (spectros[idx], label)
        for label, spectros in data_dict.items()
        for idx in sorted(spectros)
    ]

    X = np.array([image for image, _ in samples])
    y = np.array([label for _, label in samples])
    return X, y


In [None]:
# Flatten the loaded dict into a sample array X and string labels y.
X, y = convert_spectrogram_dict_to_xy(data_dict=data)

In [None]:
# Sanity check: sample count and per-sample shape, then label count.
for arr in (X, y):
    print(arr.shape)

#### 1.4 Label Encoder

In [None]:
# Maps the string modulation names in y to integer class ids; it is fitted
# later with fit_transform(y), and its classes_ are reused to name classes.
label_encoder = LabelEncoder()

#### 1.5 Data Loader

In [None]:
def prepare_dataloader(X, y, batch_size=32, shuffle=False, num_workers=2, device="cpu"):
    """
    Wrap image/label arrays in a DataLoader of (N, C, H, W) float batches.

    Parameters:
    - X: np.ndarray or torch.Tensor of images shaped (N, H, W), (N, H, W, C)
      with C in {1, 3}, or already (N, C, H, W). NumPy input becomes float32.
    - y: np.ndarray or torch.Tensor of integer class labels, shape (N,).
      NumPy input becomes int64 (torch.long).
    - batch_size, shuffle, num_workers: forwarded to DataLoader.
    - device: where the tensors are stored (str or torch.device). If this is a
      CUDA device the data already lives on the GPU, so worker processes are
      disabled (CUDA tensors are not recommended in multi-process loading).

    Returns:
    - torch.utils.data.DataLoader yielding (X_batch, y_batch).

    Raises:
    - TypeError if X or y is neither a NumPy array nor a PyTorch tensor.
    """
    # Convert NumPy arrays to PyTorch tensors
    if isinstance(X, np.ndarray):
        X = torch.tensor(X, dtype=torch.float32)
    elif not isinstance(X, torch.Tensor):
        raise TypeError("Input X must be a NumPy array or PyTorch tensor")

    if isinstance(y, np.ndarray):
        y = torch.tensor(y, dtype=torch.long)
    elif not isinstance(y, torch.Tensor):
        raise TypeError("Labels y must be a NumPy array or PyTorch tensor")

    # Ensure X has four dimensions (N, C, H, W)
    if X.ndim == 3:  # (N, H, W): add a channel dimension
        X = X.unsqueeze(1)  # (N, 1, H, W)
    elif X.ndim == 4 and X.shape[-1] in (1, 3):  # (N, H, W, C) case
        X = X.permute(0, 3, 1, 2)  # Convert to (N, C, H, W)

    target = torch.device(device)
    X, y = X.to(target), y.to(target)
    on_gpu = target.type == "cuda"

    loader = DataLoader(
        TensorDataset(X, y),
        batch_size=batch_size,
        shuffle=shuffle,
        # GPU-resident tensors must be served from the main process.
        num_workers=0 if on_gpu else num_workers,
        # Pinning only applies to CPU tensors. Requesting it for tensors that
        # were just moved to CUDA fails, and the CPU path never pinned.
        pin_memory=False,
    )

    return loader

### 2 Pre-processing

In [None]:
# Fit the encoder on the string labels and convert them to integer class ids.
y_encoded = label_encoder.fit_transform(y)

In [None]:
# Resize every spectrogram with pre_processing (64x64 by default) before splitting.
X_pre_processed = pre_processing(X)

In [None]:
# 80/20 split stratified on the encoded labels (each class keeps its share in
# both sets), with a fixed seed so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X_pre_processed, y_encoded, test_size=0.2, stratify=y_encoded, random_state=42
)

In [None]:
# Per-class sample counts, to confirm the stratified split kept classes balanced.
unique_train, counts_train = np.unique(y_train, return_counts=True)
unique_test, counts_test = np.unique(y_test, return_counts=True)

for heading, label_ids, label_counts in (
    ("Class Distribution in Training Set:", unique_train, counts_train),
    ("\nClass Distribution in Test Set:", unique_test, counts_test),
):
    print(heading)
    for label, count in zip(label_ids, label_counts):
        print(f"Class {label}: {count} samples")

#### 2.1 Pre-processing Display

In [None]:
# Index of the sample shown in the cells below. It indexes the full dataset
# (X, y, y_encoded, X_pre_processed), so for a random pick use
# np.random.randint(len(X)).
selected_display = 0

In [None]:
# These arrays hold the whole dataset (before the train/test split), so the
# messages no longer call them "Train".
print("Pre-processed images shape:", X_pre_processed.shape)
print("Encoded labels shape:", y_encoded.shape)
print(f"{y[selected_display]} = {y_encoded[selected_display]}")

In [None]:
# Show the selected spectrogram as loaded (before resizing) and print its shape.
plt.imshow(X[selected_display])
plt.show()
print(X[selected_display].shape)

In [None]:
# Show the same sample after pre_processing (resized) and print its shape.
# cmap only affects single-channel images; RGB arrays are drawn as-is.
plt.imshow(X_pre_processed[selected_display], cmap="grey")
plt.show()
print(X_pre_processed[selected_display].shape)


### 3 Training

#### 3.1 Training Setup

In [None]:
class SSRLoss(nn.Module):
    """
    Softmax cross-entropy plus an L1 pull of features toward per-class centers.

        loss = CE(logits, labels) + lambda_reg * mean(|features - centers[labels]|)

    Parameters:
    - num_classes (int): number of classes (rows of ``centers``).
    - feature_dim (int): length of each feature vector and center.
    - lambda_reg (float): weight of the L1 self-regularisation term.
    """

    def __init__(self, num_classes, feature_dim, lambda_reg=0.12):
        super().__init__()
        self.lambda_reg = lambda_reg
        self.num_classes = num_classes
        self.feature_dim = feature_dim
        # One center per class, drawn from a standard normal. As an
        # nn.Parameter it is only updated if an optimizer is given it.
        self.centers = nn.Parameter(torch.randn(num_classes, feature_dim))
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, features, labels, logits):
        """
        :param features: Feature vectors before FC layer (batch_size, feature_dim)
        :param labels: Class labels (batch_size,)
        :param logits: Output logits before softmax (batch_size, num_classes)
        :return: scalar loss tensor
        """
        classification_term = self.cross_entropy(logits, labels)

        # Mean absolute deviation of every feature from its own class center.
        deviation = features - self.centers[labels]
        regularisation_term = deviation.abs().mean()

        return classification_term + self.lambda_reg * regularisation_term

In [None]:
def train_model(
    model,
    train_loader,
    device,
    criterion,  # SSRLoss should be passed here
    optimizer,
    scheduler=None,  # Optional LR scheduler, stepped once per epoch
    epochs=10,
    patience=3,
    min_delta=0.0,
    checkpoint_every=5,
    checkpoint_path="model_latest.pth",
):
    """
    Trains the model with SSR-Loss.

    Features: per-epoch LR scheduling, periodic checkpointing, and early
    stopping on the average training loss.

    :param model: The PyTorch model (must return (features, logits))
    :param train_loader: DataLoader for training (must yield at least one batch)
    :param device: CUDA or CPU
    :param criterion: SSRLoss instance, called as criterion(features, labels, logits)
    :param optimizer: PyTorch optimizer (Adam or SGD)
    :param scheduler: Learning rate scheduler (optional), stepped after every epoch
    :param epochs: Maximum number of training epochs
    :param patience: Epochs without sufficient improvement before early stopping
    :param min_delta: Minimum loss improvement that resets the patience counter
    :param checkpoint_every: Save the whole model to ``checkpoint_path`` after
        every ``checkpoint_every``-th epoch (1-based: epochs 5, 10, ... by
        default). ``None`` or ``0`` disables checkpointing.
    :param checkpoint_path: File the checkpoint is (over)written to.
    :return: List of average epoch losses
    :raises ValueError: if ``train_loader`` has no batches.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader has no batches; cannot compute an epoch loss")

    model.to(device)
    model.train()

    loss_history = []
    best_loss = float("inf")
    patience_counter = 0

    for epoch in range(epochs):
        total_loss = 0.0

        # Progress bar for visualization
        progress_bar = tqdm(
            train_loader,
            desc=f"Epoch {epoch+1}/{epochs}",
            leave=True,
            dynamic_ncols=True,
        )

        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass - model returns (features, logits)
            features, logits = model(inputs)

            # SSR-Loss uses both the features (center term) and logits (CE term)
            loss = criterion(features, labels, logits)

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            progress_bar.set_postfix({"Loss": f"{batch_loss:.4f}"})

        # Average loss for the epoch
        avg_loss = total_loss / len(train_loader)
        loss_history.append(avg_loss)

        print(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")

        # Periodic checkpoint after every `checkpoint_every`-th epoch (1-based).
        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            torch.save(model, checkpoint_path)

        if scheduler:
            scheduler.step()

        # Early stopping on the average training loss
        if avg_loss < best_loss - min_delta:
            best_loss = avg_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    return loss_history


##### Loss Curve

In [None]:
def plot_loss_curve(loss_history, title="Training Loss Over Epochs"):
    """Plot the per-epoch training loss, numbering epochs from 1."""
    epoch_numbers = range(1, len(loss_history) + 1)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(epoch_numbers, loss_history, marker="o", label="Training Loss")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.set_title(title)
    ax.legend()
    ax.grid()
    plt.show()

##### Confusion Matrix

In [None]:
def display_confusion_matrix(
    model, data_loader, device, class_names=None, title="Confusion Matrix"
):
    """
    Generate and display a row-normalised (percentage) confusion matrix.

    Parameters:
        model (torch.nn.Module): Trained model whose forward returns (features, logits).
        data_loader (torch.utils.data.DataLoader): Yields (inputs, integer labels).
        device (torch.device): Device to run evaluation on (CPU/GPU).
        class_names (list, optional): Names for class ids 0..K-1. When given, the
            matrix always has K rows/columns, even for classes absent from the
            data. If None, uses numeric indices of the classes that occur.
        title (str): Title of the confusion matrix plot.
    """
    model.to(device)
    model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)

            # The first output is the feature vector; only the logits are needed.
            _, logits = model(inputs)
            preds = torch.argmax(logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Pin the label set when names are known, so a class missing from this
    # data still gets a row/column and the tick labels stay aligned.
    label_ids = None if class_names is None else list(range(len(class_names)))
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    num_classes = cm.shape[0]

    # Row-normalise to percentages; rows with no true samples stay 0 (not NaN).
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        cm.astype(np.float32),
        row_sums,
        out=np.zeros(cm.shape, dtype=np.float32),
        where=row_sums > 0,
    ) * 100

    if class_names is None:
        class_names = [str(i) for i in range(num_classes)]

    font_size = max(8, 12 - num_classes // 5)

    # Figure grows with the number of classes
    plt.figure(figsize=(max(10, num_classes * 0.8), max(8, num_classes * 0.6)))
    im = plt.imshow(cm_normalized, interpolation="nearest", cmap="Blues")
    plt.title(title, fontsize=14)
    plt.colorbar(im, label="Percentage")

    tick_marks = np.arange(num_classes)
    plt.xticks(tick_marks, class_names, rotation=45, ha="right", va="top", fontsize=font_size)
    plt.yticks(tick_marks, class_names, fontsize=font_size)

    # Annotate cells; switch to white text on dark cells for readability
    thresh = cm_normalized.max() / 2.0
    for i in range(num_classes):
        for j in range(num_classes):
            plt.text(
                j,
                i,
                f"{cm_normalized[i, j]:.1f}",
                ha="center",
                va="center",
                color="white" if cm_normalized[i, j] > thresh else "black",
                fontsize=font_size,
            )

    plt.ylabel("True Label", fontsize=12, labelpad=10)
    plt.xlabel("Predicted Label", fontsize=12, labelpad=10)

    # Extra bottom margin for the rotated x labels
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.2 + num_classes * 0.005)

    plt.show()

#### 3.2 Model

##### E1-5

In [None]:
class encoder(nn.Module):
    """
    Two-branch downsampling block whose branch outputs are summed.

    - branch A: 3x3 conv, stride 2, then ReLU
    - branch B: 3x3 conv, stride 2, ReLU, then 3x3 conv, stride 1

    Both branches use padding=1, so H and W become ceil(H/2), ceil(W/2) and
    the channel count becomes ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.relu = nn.ReLU()

        # Branch A
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1)

        # Branch B
        self.conv2 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        branch_a = self.relu(self.conv1(x))
        branch_b = self.conv3(self.relu(self.conv2(x)))
        return branch_a + branch_b

##### D1-3

In [None]:
class decoder(nn.Module):
    """
    Fuse four feature maps at the spatial scale of ``x2``.

    - x1 (c1 channels, twice x2's size) is 2x2 max-pooled down;
    - x2 (c2 channels) is used as-is;
    - x3 and x4 (c3 / c4 channels, half x2's size) are bilinearly upsampled 2x.

    Each path is projected to 64 channels by a 3x3 conv (padding=1), and the
    four results are concatenated into 256 channels.
    """

    def __init__(self, c1, c2, c3, c4):
        super().__init__()

        # Not used in forward(); kept so the module layout is unchanged.
        self.relu = nn.ReLU()

        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv1 = nn.Conv2d(in_channels=c1, out_channels=64, kernel_size=3, stride=1, padding=1)

        self.conv2 = nn.Conv2d(in_channels=c2, out_channels=64, kernel_size=3, stride=1, padding=1)

        self.upsample1 = nn.Upsample(scale_factor=(2, 2), mode='bilinear', align_corners=True)
        self.conv3 = nn.Conv2d(in_channels=c3, out_channels=64, kernel_size=3, stride=1, padding=1)

        self.upsample2 = nn.Upsample(scale_factor=(2, 2), mode='bilinear', align_corners=True)
        self.conv4 = nn.Conv2d(in_channels=c4, out_channels=64, kernel_size=3, stride=1, padding=1)

    def forward(self, x1, x2, x3, x4):
        projected = (
            self.conv1(self.pool1(x1)),
            self.conv2(x2),
            self.conv3(self.upsample1(x3)),
            self.conv4(self.upsample2(x4)),
        )
        return torch.cat(projected, dim=1)

##### D4

In [None]:
class decoder_4(nn.Module):
    """
    Three-input variant of ``decoder``, fusing at the scale of ``x2``.

    - x1 (c1 channels, twice x2's size) is 2x2 max-pooled down;
    - x2 (c2 channels) is used as-is;
    - x3 (c3 channels, half x2's size) is bilinearly upsampled 2x.

    Each path goes through a 3x3 conv (padding=1) to 64 channels and the
    results are concatenated into 192 channels.
    """

    def __init__(self, c1, c2, c3):
        super().__init__()

        # Not used in forward(); kept so the module layout is unchanged.
        self.relu = nn.ReLU()

        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv1 = nn.Conv2d(in_channels=c1, out_channels=64, kernel_size=3, stride=1, padding=1)

        self.conv2 = nn.Conv2d(in_channels=c2, out_channels=64, kernel_size=3, stride=1, padding=1)

        self.upsample1 = nn.Upsample(scale_factor=(2, 2), mode='bilinear', align_corners=True)
        self.conv3 = nn.Conv2d(in_channels=c3, out_channels=64, kernel_size=3, stride=1, padding=1)

    def forward(self, x1, x2, x3):
        finer = self.conv1(self.pool1(x1))
        same = self.conv2(x2)
        coarser = self.conv3(self.upsample1(x3))
        return torch.cat((finer, same, coarser), dim=1)

##### LDC-Unet

In [None]:
class LDC_Unet(nn.Module):
    """
    LDC-Unet front end: five ``encoder`` stages, four multi-scale decoder
    stages and a two-convolution head that maps back to a 3-channel image.

    Parameters:
    - in_channels (int): channels of the input image (default 3). The output
      always has 3 channels so it can feed the VGG19-based DCNN.

    Input H and W should be multiples of 32. Each encoder halves the spatial
    size (five times in total) and every decoder concatenates maps that must
    line up exactly. For such inputs the output is (N, 3, H, W).
    """

    def __init__(self, in_channels=3):
        super(LDC_Unet, self).__init__()

        # Encoders; each halves H and W (output scale noted on the right).
        self.e1 = encoder(in_channels=in_channels, out_channels=64)  # H/2
        self.e2 = encoder(in_channels=64, out_channels=128)  # H/4
        self.e3 = encoder(in_channels=128, out_channels=256)  # H/8
        self.e4 = encoder(in_channels=256, out_channels=256)  # H/16
        self.e5 = encoder(in_channels=256, out_channels=512)  # H/32

        # Decoders fuse a finer map (max-pooled), a same-scale map and coarser
        # maps (upsampled). d4 outputs 192 channels, d3..d1 output 256.
        self.d4 = decoder_4(c1=256, c2=256, c3=512)  # at H/16
        self.d3 = decoder(c1=128, c2=256, c3=256, c4=192)  # at H/8
        self.d2 = decoder(c1=64, c2=128, c3=256, c4=256)  # at H/4
        self.d1 = decoder(c1=in_channels, c2=64, c3=128, c4=256)  # at H/2

        # Head: upsample H/2 -> H, then 256 -> 64 -> 3 channels. Both convs are
        # 3x3 with padding=1, so the spatial size is preserved.
        self.upsample = nn.Upsample(scale_factor=(2, 2), mode='bilinear', align_corners=True)
        self.conv1 = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):

        x0 = x

        x_e1 = self.e1(x0)
        x_e2 = self.e2(x_e1)
        x_e3 = self.e3(x_e2)
        x_e4 = self.e4(x_e3)
        x_e5 = self.e5(x_e4)

        x_d4 = self.d4(x_e3, x_e4, x_e5)
        x_d3 = self.d3(x_e2, x_e3, x_e4, x_d4)
        x_d2 = self.d2(x_e1, x_e2, x_e3, x_d3)
        x_d1 = self.d1(x0, x_e1, x_e2, x_d2)

        x = self.upsample(x_d1)
        # NOTE(review): there is no non-linearity between conv1 and conv2;
        # confirm this matches the intended design.
        x = self.conv1(x)
        x = self.conv2(x)

        return x

##### DCNN

In [None]:
class DCNN(nn.Module):
    """
    VGG19-based classifier that returns both an embedding and class logits.

    forward(x) -> (features, logits):
    - features: (N, 256) output of ``classification_block`` after ReLU and
      dropout; this is the vector SSR-Loss compares with its class centers.
    - logits: (N, num_classes) from ``dense_last``.
    """

    def __init__(self, num_classes):
        super(DCNN, self).__init__()

        # Load pretrained VGG19 and keep only its convolutional trunk. The full
        # network is deliberately NOT stored on self: its own classifier (over
        # 100M parameters) is never used in forward(). Registering it would put
        # it in parameters(), state_dict() and every saved .pth file.
        # NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of
        # `weights=models.VGG19_Weights.IMAGENET1K_V1`.
        vgg19 = models.vgg19(pretrained=True)
        self.features = vgg19.features

        # Custom classification block
        self.classification_block = nn.Sequential(
            nn.AdaptiveAvgPool2d((7, 7)),  # Fixed 7x7 map regardless of input size
            nn.Flatten(),
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 256)  # Features before final classification
        )

        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.dense_last = nn.Linear(256, num_classes)  # Final classification layer

    def forward(self, x):
        x = self.features(x)  # Extract CNN features
        features = self.classification_block(x)  # Intermediate feature representation
        features = self.relu(features)
        features = self.dropout1(features)

        logits = self.dense_last(features)  # Final classification output

        return features, logits  # Return both features and logits


##### Main Model

In [None]:
class MainModel(nn.Module):
    """
    Full pipeline: LDC-Unet image front end followed by the DCNN classifier.

    forward(x) returns the DCNN's (features, logits) computed on the
    LDC-Unet output.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.ldc_unet = LDC_Unet()
        self.dcnn = DCNN(num_classes)

    def forward(self, x):
        enhanced = self.ldc_unet(x)
        features, logits = self.dcnn(enhanced)
        return features, logits

##### Model Summary

In [None]:
# Optional: uncomment to print a layer-by-layer summary (torchinfo.summary).
# Training inputs are 64x64 after pre_processing; LDC-Unet needs H and W to be
# multiples of 32, so the 128x128 size below also works.
# model = MainModel(6)
# summary(model , input_size=(1, 3, 128, 128))

#### 3.3 Actual Training

In [None]:
# Prepare DataLoaders
train_loader = prepare_dataloader(
    X_train,
    y_train,
    batch_size=32,
    shuffle=True,
)

In [None]:
num_classes = len(np.unique(y_encoded))

model = MainModel(num_classes=num_classes).to(device)


# SSR-Loss criterion. Its centers must be as wide as the DCNN feature vector
# (the input size of the final Linear layer), so read that width from the
# model instead of hard-coding 256.
criterion = SSRLoss(
    num_classes=num_classes,
    feature_dim=model.dcnn.dense_last.in_features,
    lambda_reg=0.12,
).to(device)

In [None]:
learning_rate = 1e-4
# Optimise the SSR-Loss class centers together with the network. They are an
# nn.Parameter of `criterion`, not of `model`; without this they would stay at
# their random initialisation forever.
optimizer = optim.SGD(
    list(model.parameters()) + list(criterion.parameters()),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=1e-5,
)
# Multiply the learning rate by 0.1 every 45 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=45, gamma=0.1)

In [None]:
# Train the model
epoch_count = 1
loss_history = train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    scheduler=scheduler,
    epochs=epoch_count,
    patience=50,
)

In [None]:
# Plot the per-epoch average training loss returned by train_model.
plot_loss_curve(loss_history)

In [None]:
# Confusion matrix on the *training* data (section 4 evaluates the held-out
# test split), labelled with the decoded modulation names.
display_confusion_matrix(
    model,
    train_loader,
    device,
    class_names=list(label_encoder.classes_),
    title="Confusion Matrix (training set)",
)

#### 3.4 Save Model to File

In [None]:
# File-name tag for the modulation families used. It is "ALL" when every
# family was loaded; otherwise the selected names are joined with "_"
# (e.g. "FM" or "FM_PM").
if set(modulation_types) == {"FM", "PM", "HYBRID"}:
    mds = "ALL"
else:
    mds = "_".join(modulation_types)

model_file_name = f"LDC_Unet_model_e{epoch_count}_lr{learning_rate}_snr_{snr}_mds_{mds}"

In [None]:
# Save the whole model object (pickled). Loading it back needs these class
# definitions to be importable; model.state_dict() is the more portable option.
torch.save(model, model_file_name + ".pth")

### 4 Testing

In [None]:
def evaluate_model(model, test_loader, label_encoder, device):
    """
    Evaluate a trained classifier and report its metrics.

    Prints the accuracy, plots a percentage confusion matrix and prints a
    per-class classification report (precision, recall, F1).

    Args:
        model: Trained PyTorch model whose forward returns (features, logits).
        test_loader: DataLoader yielding (inputs, integer labels).
        label_encoder: Fitted label encoder; its ``classes_`` name class ids 0..K-1.
        device: 'cuda' or 'cpu' (or torch.device) where evaluation happens.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.to(device)
    model.eval()

    y_true = []
    y_pred = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)

            # The first output is the DCNN feature vector (not an image); only
            # the logits are needed for evaluation.
            _, logits = model(inputs)
            preds = torch.argmax(logits, dim=1)

            y_true.extend(labels.cpu().tolist())
            y_pred.extend(preds.cpu().tolist())

    if not y_true:
        raise ValueError("test_loader produced no samples")

    correct = sum(t == p for t, p in zip(y_true, y_pred))
    accuracy = 100 * correct / len(y_true)
    print(f"Accuracy: {accuracy:.2f}%")

    class_names = label_encoder.classes_
    num_classes = len(class_names)
    # Pin the label set to every encoded class so the matrix and the report
    # stay aligned with class_names even if a class is missing from this split.
    label_ids = list(range(num_classes))

    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    # Row-normalise to percentages; rows with no true samples stay 0 (not NaN).
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        cm.astype(np.float32),
        row_sums,
        out=np.zeros(cm.shape, dtype=np.float32),
        where=row_sums > 0,
    ) * 100

    font_size = max(8, 12 - num_classes // 5)

    # Figure grows with the number of classes
    fig, ax = plt.subplots(figsize=(max(10, num_classes * 0.8), max(8, num_classes * 0.6)))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm_normalized, display_labels=class_names)
    disp.plot(cmap="Blues", values_format=".1f", ax=ax)  # 1 decimal place for percentages

    ax.set_xticklabels(class_names, rotation=45, ha="right", va="top", fontsize=font_size)
    ax.set_yticklabels(class_names, rotation=0, fontsize=font_size)
    ax.set_xlabel("Predicted Label", fontsize=12, labelpad=10)
    ax.set_ylabel("True Label", fontsize=12, labelpad=10)
    ax.set_title("Confusion Matrix (Percentage)", fontsize=14)

    # Extra bottom margin for the rotated x labels
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.2 + num_classes * 0.005)

    plt.show()

    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, labels=label_ids, target_names=class_names))

In [None]:
# Prepare DataLoaders
test_loader = prepare_dataloader(
    X_test,
    y_test,
    batch_size=32,
)

In [None]:
# Evaluate the model
# Accuracy, confusion matrix and classification report on the held-out test split.
evaluate_model(model, test_loader, label_encoder, device)