# Improved Radar Signal Recognition by Combining ResNet with Transformer Learning

### 1 Setup

#### 1.1 Imports

In [None]:
import torch
import torch.nn as nn
from torchinfo import summary
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import h5py
import numpy as np
from sklearn.preprocessing import LabelEncoder
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
import cv2
import seaborn as sns
import os

#### 1.2 Device Selection

In [None]:
# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU. Later cells pass this object to the model and loops.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

#### 1.3 Data Input

In [None]:
def load_spectrogram_h5s(root_folder, mod_types):
    """
    Collect the contents of every ``.h5`` file found in the requested
    modulation-family sub-folders of ``root_folder``.

    The result is a nested dictionary keyed first by the file name without
    its ``.h5`` suffix and then by the integer form of each HDF5 key:
      {
          "lfm_up": {0: <np.ndarray>, 1: <np.ndarray>, ...},
          ...
      }

    Parameters:
    - root_folder (str): Directory that contains one sub-folder per entry of
      ``mod_types``.
    - mod_types (list): Sub-folder names to scan, e.g. ['FM', 'PM'].

    Returns:
    - dict: ``{file_stem: {int_key: np.ndarray}}``. A missing sub-folder is
      reported and skipped. A file that fails while loading is reported and
      keeps whatever entries were read before the failure (possibly none).
    """
    loaded = {}

    for family in mod_types:
        family_dir = os.path.join(root_folder, family)

        # Guard clause: nothing to scan when the family folder is absent.
        if not os.path.exists(family_dir):
            print(f"⚠️ Warning: {family_dir} does not exist. Skipping.")
            continue

        print(f"📂 Loading from {family}...")
        h5_names = [name for name in os.listdir(family_dir) if name.endswith(".h5")]

        for h5_name in tqdm(h5_names, desc=f"   {family}", unit="file"):
            stem = h5_name[:-3]  # drop the 3-character ".h5" suffix
            h5_path = os.path.join(family_dir, h5_name)

            # Register the (initially empty) entry before reading so the
            # file name is present in the result even if loading fails.
            entries = {}
            loaded[stem] = entries

            try:
                with h5py.File(h5_path, "r") as handle:
                    for key in handle.keys():
                        # HDF5 keys are strings; they are stored as ints.
                        entries[int(key)] = np.array(handle[key])
            except Exception as e:
                print(f"❌ Failed to load {h5_path}: {e}")

    return loaded


In [None]:
# Ensure the main directory exists
# NOTE(review): this is an absolute, machine-specific path; anyone else
# running the notebook has to edit it.
data_path = "C:/Users/scocks/Documents/hehehehehehhe/images/"
os.makedirs(data_path, exist_ok=True)

# These three values are only used here to build the dataset folder name
# below (presumably image resolution, images per class and SNR in dB —
# TODO confirm against the generator that wrote the folder).
img_res = 224
img_count = 1000
snr = 0


folder_name = f"General_Images_res_{img_res}_sz_{img_count}_SNR_{snr}"
# Plain concatenation works because data_path ends with a "/".
folder_path = data_path+folder_name

# Sub-folders of folder_path to scan for .h5 files.
modulation_types = [
    "FM",
    "PM",
    "HYBRID",
]

# Nested dict: {file_stem: {int_index: np.ndarray}}.
data = load_spectrogram_h5s(folder_path, modulation_types)

##### 1.3.6 Conversion

In [None]:
def convert_spectrogram_dict_to_xy(data_dict):
    """
    Flatten a ``{label: {index: array}}`` dictionary into parallel arrays.

    Labels are visited in dictionary order; within one label the samples are
    emitted in ascending index order.

    Parameters:
    - data_dict: Output from load_spectrogram_h5s(), e.g.
        {
            "lfm_up": {0: np.array, 1: np.array, ...},
            "bpsk":   {0: np.array, ...},
            ...
        }

    Returns:
    - X: np.ndarray stacking every sample (one leading axis of length N).
    - y: np.ndarray of shape (N,) holding the label string of each sample.
    """
    # Samples: one entry per (label, sorted index) pair.
    samples = [
        per_label[index]
        for per_label in data_dict.values()
        for index in sorted(per_label.keys())
    ]

    # Labels: the label name repeated once per sample of that label.
    names = [
        name
        for name, per_label in data_dict.items()
        for _ in per_label
    ]

    return np.array(samples), np.array(names)


In [None]:
# Flatten the nested {label: {index: array}} dictionary into a sample array X
# and a parallel array y of label strings.
X, y = convert_spectrogram_dict_to_xy(data)

In [None]:
# Sanity check: X and y are built sample-by-sample, so their leading
# dimensions must be equal.
print(X.shape)
print(y.shape)

#### 1.4 Label Encoder

In [None]:
# Created unfitted here; it is fitted on y in the pre-processing section and
# read again by evaluate_model() to recover the class names.
label_encoder = LabelEncoder()

#### 1.5 Data Loader

In [None]:
def prepare_dataloader(X, y, batch_size=32, shuffle=False, num_workers=2, device="cpu"):
    """
    Wrap samples and labels in a ``DataLoader`` whose batches live on ``device``.

    Parameters:
        X (np.ndarray | torch.Tensor): Samples. NumPy input is converted to
            float32. A 3-D input (N, H, W) gets a channel axis added; a 4-D
            input whose last axis has size 1 or 3 is treated as (N, H, W, C)
            and permuted to (N, C, H, W). Any other layout is passed through.
        y (np.ndarray | torch.Tensor): Integer class labels. NumPy input is
            converted to int64.
        batch_size (int): Samples per batch.
        shuffle (bool): Whether the loader reshuffles every epoch.
        num_workers (int): Worker processes used while the data stays on the
            CPU. Ignored (forced to 0) when the data is moved to another device.
        device (str | torch.device): Device the full dataset is moved to.

    Returns:
        DataLoader: Loader over a ``TensorDataset`` of (X, y).

    Raises:
        TypeError: If X or y is neither a NumPy array nor a PyTorch tensor.
    """
    # Convert NumPy arrays to PyTorch tensors
    if isinstance(X, np.ndarray):
        X = torch.tensor(X, dtype=torch.float32)
    elif not isinstance(X, torch.Tensor):
        raise TypeError("Input X must be a NumPy array or PyTorch tensor")

    if isinstance(y, np.ndarray):
        y = torch.tensor(y, dtype=torch.long)
    elif not isinstance(y, torch.Tensor):
        raise TypeError("Labels y must be a NumPy array or PyTorch tensor")

    # Ensure X has four dimensions (N, C, H, W)
    if X.ndim == 3:  # (N, H, W): add a channel dimension
        X = X.unsqueeze(1)
    elif X.ndim == 4 and X.shape[-1] in (1, 3):  # (N, H, W, C) case
        # NOTE: a channels-first input whose width is 1 or 3 would also match
        # this test and be permuted; callers must avoid that layout.
        X = X.permute(0, 3, 1, 2)

    # Accept both strings and torch.device objects; comparing a torch.device
    # against the string "cuda" is not a reliable check.
    target = torch.device(device)
    X, y = X.to(target), y.to(target)

    # Once the dataset lives on an accelerator, worker processes must not be
    # used (CUDA tensors should not be handed to loader workers) and pinning
    # is impossible (only CPU tensors can be pinned). The previous code asked
    # for pin_memory=True exactly in the case where the data was already on
    # the GPU, which fails as soon as the loader is iterated.
    workers = num_workers if target.type == "cpu" else 0

    dataset = TensorDataset(X, y)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=workers,
        pin_memory=False,
    )

    return loader


#### 1.6 Data File/Folder Paths

In [None]:
# DATA_FOLDER = "../Data/"

In [None]:
# dataset_file = DATA_FOLDER + "images_num1000_snr0dB_sz500.h5"

In [None]:
# X,y = read_images_and_metadata(dataset_file)

### 2 Pre-processing

In [None]:
def pre_processing(images):
    """
    Turn colour images into binary masks.

    The first three entries of the last axis are combined into a luminance
    value with fixed weights, truncated to uint8, and thresholded at 128.

    Parameters:
    - images: array whose last axis holds at least three colour values;
      anything past the third entry (e.g. an alpha channel) is ignored.

    Returns:
    - np.ndarray of dtype uint8 with the last axis removed; 1 where the
      luminance is >= 128, otherwise 0.
    """
    luma_weights = [0.2989, 0.5870, 0.1140]

    # Weighted sum over the colour axis, then truncate to 8-bit integers.
    colour = images[..., :3]
    luminance = np.dot(colour, luma_weights).astype(np.uint8)

    # Threshold into a 0/1 mask.
    mask = np.greater_equal(luminance, 128)
    return mask.astype(np.uint8)

In [None]:
# Fit the encoder on the string labels and replace them with integer ids.
y_encoded = label_encoder.fit_transform(y)

In [None]:
# Greyscale + threshold every sample; the result is a uint8 0/1 array with
# the colour axis removed.
X_pre_processed = pre_processing(X)

In [None]:
# 80/20 split, stratified on the encoded labels so both parts keep the class
# proportions; the fixed random_state makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X_pre_processed, y_encoded, test_size=0.2, stratify=y_encoded,random_state=42)

In [None]:
# Count how many samples of each encoded class ended up in each split.
unique_train, counts_train = np.unique(y_train, return_counts=True)
unique_test, counts_test = np.unique(y_test, return_counts=True)

# The labels printed here are the integer ids produced by label_encoder,
# not the original modulation names.
print("Class Distribution in Training Set:")
for label, count in zip(unique_train, counts_train):
    print(f"Class {label}: {count} samples")

print("\nClass Distribution in Test Set:")
for label, count in zip(unique_test, counts_test):
    print(f"Class {label}: {count} samples")

#### 2.1 Pre-processing Display

In [None]:
# Index of the sample shown by the display cells below.
# NOTE(review): the commented-out random variant draws from len(X_train),
# but the cells below index X, X_pre_processed, y and y_encoded, i.e. the
# full, unsplit dataset.
# selected_display = np.random.randint(len(X_train))
selected_display = 0

In [None]:
# NOTE(review): the printed captions say "Train", but X_pre_processed and
# y_encoded are the full dataset before the train/test split.
print("Train Images shape:", X_pre_processed.shape)
print("Train Metadata shape:", y_encoded.shape)
# Show the string label of the selected sample next to its encoded id.
print(f"{y[selected_display]} = {y_encoded[selected_display]}")

In [None]:
# Show the selected sample as loaded, before greyscale/threshold.
plt.imshow(X[selected_display])
plt.show()

In [None]:
# Show the binarised version of the same sample. "gray" is the canonical
# colormap name and works on every Matplotlib release; the "grey" spelling
# is only an alias registered by newer releases and fails on older ones.
plt.imshow(X_pre_processed[selected_display], cmap="gray")
plt.show()
print(X_pre_processed[selected_display].shape)


### 3 Training

#### 3.1 Training Setup

In [None]:
def train_model_separate_optim(
    model, train_loader, scr_loss_fn, model_optimizer, center_optimizer,
    device, epochs=10, patience=3, min_delta=0.0
):
    """
    Trains the model with SCR-loss using separate optimizers for model and center-loss parameters.
    Includes a tqdm progress bar and early stopping on the training loss.

    Each batch is processed in two steps:
      1. The combined loss is back-propagated and ``model_optimizer`` steps.
      2. The center-loss is recomputed on detached logits/features, so its
         gradient only reaches the parameters of ``scr_loss_fn``, and
         ``center_optimizer`` steps. Gradients that step 1 left on those
         parameters are cleared first by ``center_optimizer.zero_grad()``.

    Parameters:
        model (torch.nn.Module): The model to train; must return ``(logits, features)``.
        train_loader (DataLoader): DataLoader for training data.
        scr_loss_fn (callable): Returns ``(total_loss, ce_loss, center_loss)``.
        model_optimizer (torch.optim.Optimizer): Optimizer for model parameters.
        center_optimizer (torch.optim.Optimizer): Optimizer for center-loss parameters.
        device (torch.device): Device to run the model on.
        epochs (int): Maximum number of epochs.
        patience (int): Number of epochs with no improvement before early stopping.
        min_delta (float): Minimum loss improvement to reset patience.

    Returns:
        loss_history (list[float]): The average combined loss per completed epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches in an epoch.
    """
    model.to(device)
    model.train()
    loss_history = []

    best_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        # Plain-float accumulator for the epoch. It must not share a name with
        # the per-batch loss tensor: the earlier version overwrote it on every
        # batch, so the "average" was just twice the last batch loss.
        running_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=True, dynamic_ncols=True)

        for inputs, labels in progress_bar:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # ======== Step 1: Train Model (Cross-Entropy + Center Loss) ========
            model_optimizer.zero_grad()

            # Forward pass (get logits & feature embeddings)
            logits, features = model(inputs)

            # Combined loss plus its two components
            batch_loss, loss_ce, loss_c = scr_loss_fn(logits, features, labels)

            batch_loss.backward()
            model_optimizer.step()  # Update model weights

            # ======== Step 2: Train Center Loss Separately ========
            center_optimizer.zero_grad()

            # Same forward outputs, detached: the gradient of this center-loss
            # flows only into the loss function's own parameters.
            _, _, loss_c = scr_loss_fn(logits.detach(), features.detach(), labels)

            loss_c.backward()
            center_optimizer.step()  # Update center vectors separately

            # .item() detaches the value, so no autograd graph is kept alive.
            batch_loss_value = batch_loss.item()
            running_loss += batch_loss_value
            num_batches += 1

            # Live loss update on the progress bar
            progress_bar.set_postfix({'Total': f"{batch_loss_value:.4f}", 'CE': f"{loss_ce.item():.4f}", 'Center': f"{loss_c.item():.4f}"})

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches; cannot compute an average loss")

        # Average training loss for the epoch
        avg_loss = running_loss / num_batches
        loss_history.append(avg_loss)

        print(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")

        # Early stopping check
        if avg_loss < best_loss - min_delta:
            best_loss = avg_loss
            patience_counter = 0  # Reset patience counter on improvement
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break  # Exit training loop

    return loss_history

In [None]:
def plot_loss_curve(loss_history, title="Training Loss Over Epochs"):
    """
    Draw the per-epoch loss values as a line plot and show it.

    Args:
        loss_history (list): One loss value per epoch. Entries exposing a
            ``detach`` attribute (e.g. torch tensors, including CUDA ones)
            are detached, moved to the CPU and unwrapped; everything else is
            passed through ``float``.
        title (str): The title of the plot.
    """
    def _as_float(value):
        # Tensor-like values need detach -> cpu -> item; plain numbers do not.
        if hasattr(value, "detach"):
            return value.detach().cpu().item()
        return float(value)

    losses = list(map(_as_float, loss_history))
    epoch_axis = range(1, len(losses) + 1)  # epochs are numbered from 1

    plt.figure(figsize=(8, 5))
    plt.plot(epoch_axis, losses, marker="o", label="Training Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title(title)
    plt.legend()
    plt.grid()
    plt.show()


In [None]:
def display_confusion_matrix(
    model, data_loader, device, class_names=None, title="Confusion Matrix"
):
    """
    Generate and display a row-normalized (percentage) confusion matrix.

    The model is moved to ``device`` and switched to evaluation mode; it is
    left in evaluation mode when the function returns.

    Parameters:
        model (torch.nn.Module): Trained model returning ``(logits, features)``.
        data_loader (torch.utils.data.DataLoader): DataLoader for evaluation dataset.
        device (torch.device): Device to run evaluation on (CPU/GPU).
        class_names (list, optional): Class names, where position ``i`` names
            label id ``i``. When given, the matrix always has one row/column
            per name, even for classes absent from the data. If None, the
            classes present in the data are used with numeric tick labels.
        title (str): Title of the confusion matrix plot.
    """
    # Switch to evaluation mode
    model.to(device)  # Ensure model is on correct device
    model.eval()

    # Lists to store all predictions and labels
    all_preds = []
    all_labels = []

    # Disable gradient calculations for inference
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits, _ = model(inputs)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Pin the label set to the supplied names so the matrix size cannot
    # drift away from the tick labels when a class is missing from the data.
    label_ids = None if class_names is None else list(range(len(class_names)))
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    num_classes = cm.shape[0]

    # Normalize each row to percentages. A row with no true samples would
    # divide by zero, so such rows are left at 0 instead of becoming NaN.
    counts = cm.astype(np.float32)
    row_totals = counts.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        counts, row_totals, out=np.zeros_like(counts), where=row_totals > 0
    ) * 100

    # If class_names isn't given, use numeric class indices
    if class_names is None:
        class_names = [str(i) for i in range(num_classes)]

    # Plotting
    plt.figure(figsize=(max(10, num_classes * 0.8), max(8, num_classes * 0.6)))  # Dynamic size
    im = plt.imshow(cm_normalized, interpolation='nearest', cmap='Blues')
    plt.title(title, fontsize=14)
    plt.colorbar(im, label="Percentage")  # Add colorbar with label

    # Create tick marks; the font shrinks slowly as the class count grows
    tick_font = max(8, 12 - num_classes // 5)
    tick_marks = np.arange(num_classes)
    plt.xticks(tick_marks, class_names, rotation=45, ha="right", va="top", fontsize=tick_font)
    plt.yticks(tick_marks, class_names, fontsize=tick_font)

    # Annotate the matrix cells with percentage values; dark cells get
    # white text so the number stays readable.
    thresh = cm_normalized.max() / 2.0
    for i in range(num_classes):
        for j in range(num_classes):
            plt.text(
                j, i, f"{cm_normalized[i, j]:.1f}",
                ha="center", va="center",
                color="white" if cm_normalized[i, j] > thresh else "black",
                fontsize=tick_font
            )

    plt.ylabel("True Label", fontsize=12, labelpad=10)
    plt.xlabel("Predicted Label", fontsize=12, labelpad=10)

    # Adjust layout with extra bottom margin for rotated labels
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.2 + num_classes * 0.005)  # Dynamic bottom margin

    plt.show()

#### 3.2 Model

##### BTNK1

In [None]:
class BTNK1(nn.Module):
    """
    Bottleneck block with a projection shortcut.

    Main path: 1x1 conv (stride S) -> BN -> ReLU -> 3x3 conv -> BN -> ReLU
    -> 1x1 conv to 4*C1 channels -> BN. The shortcut is a 1x1 conv
    (stride S) to 4*C1 channels followed by BN. Both paths are summed and
    passed through a final ReLU.

    Args:
        C: input channels.
        W: accepted for signature compatibility; not used by any layer.
        C1: bottleneck (mid) channels; the block outputs 4*C1 channels.
        S: stride of the first 1x1 conv and of the shortcut conv.
    """
    def __init__(self, C, W, C1, S):
        super().__init__()
        expanded = C1 * 4

        # Main path
        self.conv1 = nn.Conv2d(C, C1, kernel_size=1, stride=S)
        self.bn1 = nn.BatchNorm2d(C1)

        self.conv2 = nn.Conv2d(C1, C1, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(C1)

        self.conv3 = nn.Conv2d(C1, expanded, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(expanded)

        # Projection shortcut
        self.conv4 = nn.Conv2d(C, expanded, kernel_size=1, stride=S)
        self.bn4 = nn.BatchNorm2d(expanded)

        self.relu = nn.ReLU()

    def forward(self, x):
        # Shortcut branch is computed from the untouched input.
        residual = self.bn4(self.conv4(x))

        # Main branch: reduce -> spatial conv -> expand.
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.relu(self.bn2(self.conv2(main)))
        main = self.bn3(self.conv3(main))

        main += residual
        return self.relu(main)

##### BTNK2

In [None]:
class BTNK2(nn.Module):
    """
    Bottleneck block with an identity shortcut.

    Main path: 1x1 conv to C/4 channels -> BN -> ReLU -> 3x3 conv -> BN
    -> ReLU -> 1x1 conv back to C channels -> BN. The input is added to the
    result and a final ReLU is applied, so the output shape equals the
    input shape.

    Args:
        C: input and output channels (C/4 is truncated to an int).
        W: accepted for signature compatibility; not used by any layer.
    """
    def __init__(self, C, W):
        super().__init__()
        mid = C // 4

        self.conv1 = nn.Conv2d(C, mid, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(mid)

        self.conv2 = nn.Conv2d(mid, mid, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(mid)

        self.conv3 = nn.Conv2d(mid, C, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(C)

        self.relu = nn.ReLU()

    def forward(self, x):
        # Reduce -> spatial conv -> expand, then add the unchanged input.
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.relu(self.bn2(self.conv2(main)))
        main = self.bn3(self.conv3(main))

        main += x
        return self.relu(main)

##### MHSA

In [None]:
class MHSA2D(nn.Module):
    """
    Multi-head self-attention over the spatial positions of a feature map.

    Every (h, w) position becomes one token of size C. The tokens attend to
    each other, the attention output is added back to the tokens (with
    dropout), the sum is layer-normalised over C, and the result is returned
    in (N, C, H, W) layout.
    """
    def __init__(self, embed_dim, num_heads=4, dropout=0.0):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ln = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch, channels, height, width = x.size()

        # (N, C, H, W) -> (N, H, W, C) -> (N, H*W, C): one token per position.
        tokens = x.permute(0, 2, 3, 1).reshape(batch, height * width, channels)

        # Self-attention: queries, keys and values are all the same tokens.
        attended, _ = self.mha(tokens, tokens, tokens)

        # Residual connection followed by normalisation over the channel axis.
        normed = self.ln(tokens + self.dropout(attended))

        # (N, H*W, C) -> (N, H, W, C) -> (N, C, H, W).
        return normed.view(batch, height, width, channels).permute(0, 3, 1, 2)


##### BTNK1*

In [None]:
class BTNK1Star(nn.Module):
    """
    BTNK1*(C, W, C1, S): bottleneck block whose 3x3 conv is replaced by
    multi-head self-attention.

    Layers, in order:
      - 1x1 conv (stride=S, no bias) + BN + ReLU
      - MHSA over the C1-channel feature map
      - 1x1 conv to 4*C1 channels (no bias) + BN
      - shortcut: identity, or 1x1 conv (stride=S) + BN when the stride is
        not 1 or the channel count changes
      - sum + ReLU

    Args:
      C  : input channels
      W  : input spatial size; stored for reference only, no layer reads it
      C1 : mid-channels (bottleneck)
      S  : stride
      num_heads : attention heads passed to MHSA2D
      dropout   : dropout rate passed to MHSA2D
    """
    def __init__(self, C, W, C1, S=2, num_heads=4, dropout=0.0):
        super().__init__()

        # Keep the constructor arguments available on the instance.
        self.C, self.W, self.C1, self.S = C, W, C1, S
        self.out_channels = 4 * C1

        # Reduce
        self.conv1 = nn.Conv2d(C, C1, kernel_size=1, stride=S, bias=False)
        self.bn1 = nn.BatchNorm2d(C1)
        self.relu = nn.ReLU(inplace=True)

        # Attention in place of the spatial convolution
        self.mhsa = MHSA2D(embed_dim=C1, num_heads=num_heads, dropout=dropout)

        # Expand
        self.conv2 = nn.Conv2d(C1, self.out_channels, kernel_size=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.out_channels)

        # A projection is only needed when the shortcut shape would not
        # match the main path's output.
        needs_projection = S != 1 or C != self.out_channels
        self.downsample = None
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(C, self.out_channels, kernel_size=1, stride=S, bias=False),
                nn.BatchNorm2d(self.out_channels),
            )

    def forward(self, x):
        # Main path: reduce -> attention -> expand.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.mhsa(out)
        out = self.bn2(self.conv2(out))

        # Shortcut path.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


##### BTNK2*

In [None]:
class BTNK2Star(nn.Module):
    """
    BTNK2*(C, W): identity-shortcut bottleneck whose 3x3 conv is replaced by
    multi-head self-attention.

      - 1x1 conv to C // 4 channels (no bias), BN, ReLU
      - MHSA over the reduced feature map
      - 1x1 conv back to C channels (no bias), BN
      - add the input, ReLU

    The output has the same shape as the input. ``W`` is stored for
    reference only. ``C // 4`` must be divisible by ``num_heads`` for the
    attention layer to be constructible.
    """
    def __init__(self, C, W, num_heads=4, dropout=0.0):
        super().__init__()
        self.C, self.W = C, W

        reduced = C // 4

        # Reduce
        self.conv1 = nn.Conv2d(C, reduced, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(reduced)
        self.relu = nn.ReLU(inplace=True)

        # Attention in place of the spatial convolution
        self.mhsa = MHSA2D(embed_dim=reduced, num_heads=num_heads, dropout=dropout)

        # Expand
        self.conv2 = nn.Conv2d(reduced, C, kernel_size=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(C)

    def forward(self, x):
        # Reduce -> attention -> expand.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.mhsa(out)
        out = self.bn2(self.conv2(out))

        # Identity shortcut.
        out += x
        return self.relu(out)


##### CNN Feature Extractor

In [None]:
class CNN_feature_extractor(nn.Module):
    """
    Convolutional front end: a 7x7 stride-2 stem on a single-channel input,
    a 3x3 stride-2 max-pool, then two bottleneck stages.

    Stage 1: BTNK1(64 -> 256, stride 1) followed by two BTNK2(256).
    Stage 2: BTNK1(256 -> 512, stride 2) followed by three BTNK2(512).

    ``num_classes`` is accepted but not used by any layer here.
    """
    def __init__(self,num_classes=6):
        super().__init__()

        self.relu = nn.ReLU()

        # Stem
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxPool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 1
        self.btnk1_1 = BTNK1(64, 56, 64, 1)
        self.btnk2_1 = BTNK2(256, 56)
        self.btnk2_2 = BTNK2(256, 56)

        # Stage 2
        self.btnk1_2 = BTNK1(256, 56, 128, 2)
        self.btnk2_3 = BTNK2(512, 28)
        self.btnk2_4 = BTNK2(512, 28)
        self.btnk2_5 = BTNK2(512, 28)

    def forward(self, x):
        # Stem: conv -> BN -> ReLU -> max-pool.
        x = self.maxPool1(self.relu(self.bn1(self.conv1(x))))

        # Bottleneck stages, applied in declaration order.
        stages = (
            self.btnk1_1, self.btnk2_1, self.btnk2_2,
            self.btnk1_2, self.btnk2_3, self.btnk2_4, self.btnk2_5,
        )
        for stage in stages:
            x = stage(x)
        return x

##### Transformer Learning

In [None]:
class Transformer_learning(nn.Module):
    """
    Attention-based back end built from the starred bottleneck blocks.

    Stage 1: BTNK1Star(512 -> 1024, stride 2) followed by five BTNK2Star(1024).
    Stage 2: BTNK1Star(1024 -> 2048, stride 2) followed by two BTNK2Star(2048).

    ``num_classes`` is accepted but not used by any layer here.
    """
    def __init__(self,num_classes=6):
        super().__init__()

        # Stage 1
        self.btnk1star_1 = BTNK1Star(512, 28, 256, S=2)
        self.btnk2star_1 = BTNK2Star(1024, 14)
        self.btnk2star_2 = BTNK2Star(1024, 14)
        self.btnk2star_3 = BTNK2Star(1024, 14)
        self.btnk2star_4 = BTNK2Star(1024, 14)
        self.btnk2star_5 = BTNK2Star(1024, 14)

        # Stage 2
        self.btnk1star_2 = BTNK1Star(1024, 14, 512, S=2)
        self.btnk2star_6 = BTNK2Star(2048, 7)
        self.btnk2star_7 = BTNK2Star(2048, 7)

    def forward(self, x):
        # Blocks are applied strictly in declaration order.
        pipeline = (
            self.btnk1star_1,
            self.btnk2star_1, self.btnk2star_2, self.btnk2star_3,
            self.btnk2star_4, self.btnk2star_5,
            self.btnk1star_2,
            self.btnk2star_6, self.btnk2star_7,
        )
        for block in pipeline:
            x = block(x)
        return x

##### SCR Loss + Center Loss

##### Main Model

In [None]:
class MainModel(nn.Module):
    """
    Full classifier: convolutional feature extractor, attention-based stage,
    global average pooling and a linear classification head.

    ``forward`` returns both the class logits and the pooled 2048-dimensional
    feature vector, because the SCR loss needs the features as well.
    """
    def __init__(self, num_classes):
        super().__init__()

        self.cnn_fe = CNN_feature_extractor(num_classes=num_classes)
        self.tl = Transformer_learning(num_classes=num_classes)

        self.globalPool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        # Backbone: convolutional stages, then the attention stages.
        feature_map = self.tl(self.cnn_fe(x))

        # Pool each channel to one value and drop the 1x1 spatial axes
        # -> (batch_size, 2048).
        features = torch.flatten(self.globalPool(feature_map), 1)

        # Classification head -> (batch_size, num_classes).
        logits = self.fc(features)

        return logits, features


#### Loss

In [None]:
class SCRLoss(nn.Module):
    """
    SCR-loss: cross-entropy plus a weighted center loss.

    One trainable center vector is kept per class. The center loss is half
    the squared Euclidean distance between each feature vector and the
    center of its label, averaged over the batch.
    """
    def __init__(self, num_classes, feat_dim, lambda_c=1.0):
        """
        Args:
            num_classes (int): number of classes (rows of the center table).
            feat_dim (int): dimensionality of the feature space.
            lambda_c (float): weight of the center-loss term in the total.
        """
        super().__init__()
        # Class centers, randomly initialised and learned during training.
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.lambda_c = lambda_c
        # Standard softmax classification loss on the logits.
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, logits, features, labels):
        """
        Args:
            logits (torch.Tensor): raw class scores, shape [batch_size, num_classes].
            features (torch.Tensor): feature vectors, shape [batch_size, feat_dim].
            labels (torch.LongTensor): class ids, shape [batch_size].

        Returns:
            total_loss (torch.Tensor): ce_loss + lambda_c * c_loss
            ce_loss (torch.Tensor): the cross-entropy component
            c_loss (torch.Tensor): the center-loss component
        """
        ce_loss = self.cross_entropy(logits, labels)

        # Per-sample squared distance to the center of the sample's class.
        offsets = features - self.centers[labels]
        squared_distance = offsets.pow(2).sum(dim=1)

        # Half the batch mean of those distances.
        c_loss = 0.5 * squared_distance.mean()

        return ce_loss + self.lambda_c * c_loss, ce_loss, c_loss


##### Model Summary

In [None]:
# Throw-away instance used only to print the layer summary for one
# single-channel 224x224 input; `model` is rebuilt with the real class
# count in the training section.
model = MainModel(6)
summary(model , input_size=(1, 1, 224, 224))

#### 3.3 Actual Training

In [None]:
# Prepare DataLoaders
train_loader = prepare_dataloader(
    X_train,
    y_train,
    batch_size=32,
    shuffle=True,
)

In [None]:
# One output unit per distinct encoded label.
num_classes = len(np.unique(y_encoded))

model = MainModel(num_classes=num_classes).to(device)

# Define loss function and optimizer
# feat_dim must match the 2048-dimensional feature vector MainModel returns.
lambda_c = 0.3
scr_loss_fn = SCRLoss(num_classes=num_classes, feat_dim=2048, lambda_c=lambda_c).to(device)

In [None]:
learning_rate = 1e-3  # Typical for Adam
center_lr = 0.05  # LR for the class-center updates (larger than the model LR above)

# Optimizer for the model
model_optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Optimizer for the center-loss parameters
# (the parameters of scr_loss_fn, i.e. its per-class center table)
center_optimizer = torch.optim.SGD(scr_loss_fn.parameters(), lr=center_lr)


In [None]:
# Train the model
epoch_count = 1
loss_history = train_model_separate_optim(
    model, 
    train_loader, 
    scr_loss_fn, 
    model_optimizer, 
    center_optimizer,  
    device, 
    epochs=epoch_count, 
    patience=50, 
    min_delta=0.01
)


In [None]:
# Plot the per-epoch average training loss returned by the training run.
plot_loss_curve(loss_history)

In [None]:
# Confusion matrix on the TRAINING data (numeric tick labels, since no
# class names are passed); the held-out evaluation is in section 4.
display_confusion_matrix(model, train_loader, device)

#### 3.4 Save Model to File

In [None]:
# Tag for the file name: "ALL" when exactly three modulation families were
# loaded, otherwise the name of the first one.
# NOTE(review): with two families selected, only the first appears in the name.
mds = modulation_types[0]
if len(modulation_types) == 3:
    mds = "ALL"

# File stem encoding epochs, model learning rate, SNR and modulation tag.
model_file_name = f"RM_Net_model_e{epoch_count}_lr{learning_rate}_snr_{snr}_mds_{mds}"

In [None]:
# Saves the whole module object, not just its weights.
# NOTE(review): loading such a file needs the same class definitions to be
# importable; saving model.state_dict() is the more portable option, but
# switching would change the file format existing loaders expect.
torch.save(model, model_file_name+".pth")

### 4 Testing

In [None]:
def evaluate_model(model, test_loader, label_encoder, device):
    """
    Evaluates the trained model and displays accuracy, a row-normalized
    confusion matrix, and the per-class precision/recall/F1 report.

    The confusion matrix and the report are always laid out over every class
    known to ``label_encoder`` (label ids ``0 .. len(classes_) - 1``), so a
    class that happens to be absent from the test data cannot shift the rows
    or break the class-name labelling.

    The model is moved to ``device`` and left in evaluation mode.

    Args:
        model: Trained PyTorch model returning ``(logits, features)``.
        test_loader: DataLoader for test set.
        label_encoder: Fitted label encoder; ``classes_`` supplies the names.
        device: 'cuda' or 'cpu' where evaluation happens.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.to(device)  # Ensure model is on correct device
    model.eval()  # Set to evaluation mode

    y_true = []
    y_pred = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Model returns (logits, features) -> only the logits are needed
            logits, _ = model(inputs)

            # Predicted class = index of the largest logit
            _, predicted = torch.max(logits, 1)

            y_true.extend(labels.cpu().tolist())  # Move to CPU for metrics
            y_pred.extend(predicted.cpu().tolist())

    total = len(y_true)
    if total == 0:
        raise ValueError("test_loader yielded no samples; nothing to evaluate")

    # Compute Accuracy
    correct = sum(truth == guess for truth, guess in zip(y_true, y_pred))
    accuracy = 100 * correct / total
    print(f"Accuracy: {accuracy:.2f}%")

    # Fix the label set to the encoder's classes so matrix rows, tick labels
    # and report rows all refer to the same ids.
    class_names = label_encoder.classes_
    num_classes = len(class_names)
    label_ids = np.arange(num_classes)

    cm = confusion_matrix(y_true, y_pred, labels=label_ids)

    # Normalize each row to percentages; rows without any true sample stay
    # at 0 instead of producing NaN from a division by zero.
    counts = cm.astype(np.float32)
    row_totals = counts.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        counts, row_totals, out=np.zeros_like(counts), where=row_totals > 0
    ) * 100

    # Plot confusion matrix
    fig, ax = plt.subplots(figsize=(max(10, num_classes * 0.8), max(8, num_classes * 0.6)))  # Dynamic size
    disp = ConfusionMatrixDisplay(confusion_matrix=cm_normalized, display_labels=class_names)
    disp.plot(cmap="Blues", values_format=".1f", ax=ax)  # Use 1 decimal place for percentages

    # Adjust x-axis label alignment and font sizes
    tick_font = max(8, 12 - num_classes // 5)
    ax.set_xticklabels(class_names, rotation=45, ha="right", va="top", fontsize=tick_font)
    ax.set_yticklabels(class_names, rotation=0, fontsize=tick_font)
    ax.set_xlabel("Predicted Label", fontsize=12, labelpad=10)
    ax.set_ylabel("True Label", fontsize=12, labelpad=10)
    ax.set_title("Confusion Matrix (Percentage)", fontsize=14)

    # Adjust layout with extra bottom margin for rotated labels
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.2 + num_classes * 0.005)  # Dynamic bottom margin

    plt.show()

    # Print Classification Report (Precision, Recall, F1-score)
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, labels=label_ids, target_names=class_names))

In [None]:
# Prepare DataLoaders
test_loader = prepare_dataloader(
    X_test,
    y_test,
    batch_size=32,
)

In [None]:
# Evaluate the model
evaluate_model(model, test_loader, label_encoder,device)