## PyTorch VGGNet13 example for RadarML Dataset

#### Import all required libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.io as io
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import os
import numpy as np
from tqdm import tqdm

In [2]:
seed = 42

# Seed every RNG this notebook draws from: NumPy, torch (CPU) and all CUDA devices.
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn

# Give up cuDNN autotuning in exchange for run-to-run reproducible kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

#### Detect if model will run on CPU or GPU

In [3]:
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


#### Spectrogram Generation Function

In [4]:
import matplotlib
import matplotlib.pyplot as plt
import cv2

# Non-interactive backend: this notebook writes figures with savefig rather than showing them.
matplotlib.use('Agg')

# Final sample rate = RF sample rate divided by the total decimation factor.
_RF_SAMPLE_RATE = 3.84e9
_TOTAL_DECIMATION = 32
fs = _RF_SAMPLE_RATE / _TOTAL_DECIMATION

def normalise_to_uint8(value, db=0, dynamic_range=40):
    """Scale an array (or a single number) to the uint8 range 0..255.

    Arrays are min/max normalised. When ``db`` is truthy, the lower bound is raised
    to ``int(max - dynamic_range)`` (but never below ``int(min)``), so everything
    more than ``dynamic_range`` below the peak maps to 0. Non-array inputs are
    multiplied by 255 and clipped to [0, 255].

    The input array is never modified.

    Args:
        value: ``np.ndarray`` or a scalar.
        db: If truthy, apply the ``dynamic_range`` floor described above.
        dynamic_range: Span kept below the maximum when ``db`` is set.

    Returns:
        ``np.ndarray`` of dtype uint8 (same shape as ``value``) for array input,
        otherwise a single ``np.uint8``.
    """
    if not isinstance(value, np.ndarray):
        # Scalar path: no range information, so just scale by 255 and saturate.
        return np.uint8(np.clip(value * 255, 0, 255))

    max_val = np.max(value)
    if db:
        # int() truncation is kept from the original implementation. Values below
        # min_val need no explicit flooring: they go negative and np.clip sends them to 0.
        min_val = np.max([int(max_val - dynamic_range), int(np.min(value))])
    else:
        min_val = np.min(value)

    span = max_val - min_val
    if span <= 0:
        # Constant frame, or (after int truncation in dB mode) a floor at/above the peak:
        # avoid 0/0 -> NaN and return an all-zero image.
        return np.zeros(value.shape, dtype=np.uint8)

    normalized = np.clip((value - min_val) / span * 255, 0, 255)
    return normalized.astype(np.uint8)


def SpectrogramGenerator(cmplx_data, filename):
    """Write a 224x224 grayscale spectrogram image of a complex signal.

    Steps:
    1. Compute a 256-point STFT: Hann window, hop of 128 (50% overlap). A trailing
       partial segment is dropped, not padded.
    2. Convert power to dB.
    3. Min/max-normalise to 0..255 with ``normalise_to_uint8``.
    4. Resize to 224x224 with bilinear interpolation.
    5. Write the result to ``filename`` with OpenCV.

    Image rows are frequency bins in ``fftshift`` order (most negative frequency
    first). Columns are successive time segments.

    Args:
        cmplx_data: 1-D complex sample array.
        filename: Output image path; its extension selects the OpenCV encoder.

    Raises:
        ValueError: If ``cmplx_data`` holds fewer than 256 samples (no full segment).
    """
    NFFT = 256
    noverlap = 128
    step = NFFT - noverlap  # hop size (128 for 256 FFT w/ 50% overlap)

    L = cmplx_data.shape[0]
    if L < NFFT:
        raise ValueError(f"need at least {NFFT} samples for a spectrogram, got {L}")
    num_segments = 1 + (L - NFFT) // step

    # Frame the whole signal at once: row i holds samples [i*step, i*step + NFFT).
    frame_idx = np.arange(num_segments)[:, None] * step + np.arange(NFFT)
    segments = cmplx_data[frame_idx] * np.hanning(NFFT)

    # FFT every segment, centre DC, then transpose to (frequency, time).
    freq_data = np.fft.fftshift(np.fft.fft(segments, NFFT, axis=1), axes=1).T

    power = np.abs(freq_data) ** 2
    power_db = 10 * np.log10(power + 1e-12)  # small offset avoids log10(0)

    # Resize the 8-bit levels as floats (bilinear), then round and saturate back to
    # uint8 explicitly rather than relying on imwrite's implicit depth conversion.
    img = normalise_to_uint8(power_db).astype(np.float64)
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
    img = np.clip(np.rint(img), 0, 255).astype(np.uint8)
    cv2.imwrite(filename, img)

## Generate Spectrograms from processed (SNR degraded) numpy files

#### Make sure the "target_SNR_dBs" list matches the one used in the "make_dataset.py" script

In [None]:
from collections import defaultdict
import multiprocessing
from concurrent.futures import ProcessPoolExecutor

# Maps the SNR index of each processed .npy file to an SNR value in dB:
# index 0 is 30 dB, descending in 3 dB steps to -30 dB at index 20.
# (This needs to match the "make_dset.py" script.)
target_SNR_dBs = list(range(30, -31, -3))

# Which SNRs feed each split: training uses 0 dB and above, testing uses all of them.
train_SNR_dBs = [snr for snr in target_SNR_dBs if snr >= 0]
test_SNR_dBs = target_SNR_dBs

# Fraction of each file's pulses (taken from the start) assigned to training.
train_test_split = 0.8

# Core function to process a single .npy file
def process_file(data):
    """Render the train/test spectrogram PNGs for one processed .npy file.

    Args:
        data: ``(file_path, total_pulses)``. ``total_pulses`` caps how many pulses,
            taken from the start of the file, are rendered per channel and SNR.

    Images are written to
    ``<root_spectrogram_folder>/ADC<channel>/{train,test}/<modulation>/``
    and named ``<file stem>_<channel>_<snr>dB_<pulse>.png``.

    Split rules:
    - The first ``int(pulses * train_test_split)`` pulses go to ``train``,
      only at SNRs listed in ``train_SNR_dBs``.
    - The remaining pulses go to ``test``, only at SNRs in ``test_SNR_dBs``.

    Reads the module-level globals ``root_spectrogram_folder``, ``train_test_split``,
    ``target_SNR_dBs``, ``train_SNR_dBs`` and ``test_SNR_dBs``.
    """
    file_path, total_pulses = data[0], data[1]

    fname = os.path.basename(file_path)
    name_without_ext = os.path.splitext(fname)[0]

    # Modulation = second '_'-separated field of the file name, minus any '#...' suffix.
    modulation_name = name_without_ext.split('_')[1].split('#')[0]

    raw_data = np.load(file_path)

    # Indexed below as raw_data[channel][snr_idx][pulse] -> interleaved real/imag row,
    # so axis 2 is the pulse axis. None of this depends on the channel: compute once.
    max_available_pulses = raw_data.shape[2]
    pulses_to_use = min(total_pulses, max_available_pulses)
    split_index = int(pulses_to_use * train_test_split)

    # (snr_idx, snr_db) pairs rendered for each split, resolved once per file instead of
    # re-checking list membership for every pulse.
    train_snrs = [(i, snr) for i, snr in enumerate(target_SNR_dBs) if snr in train_SNR_dBs]
    test_snrs = [(i, snr) for i, snr in enumerate(target_SNR_dBs) if snr in test_SNR_dBs]

    for channel_idx in range(len(raw_data)):
        # One sub-folder per modulation so the datasets can derive class labels from it.
        channel_train_folder = os.path.join(root_spectrogram_folder, f"ADC{channel_idx}", "train", modulation_name)
        channel_test_folder = os.path.join(root_spectrogram_folder, f"ADC{channel_idx}", "test", modulation_name)
        os.makedirs(channel_train_folder, exist_ok=True)
        os.makedirs(channel_test_folder, exist_ok=True)

        for pulse_idx in range(pulses_to_use):
            if pulse_idx < split_index:
                current_dir, snr_pairs = channel_train_folder, train_snrs
            else:
                current_dir, snr_pairs = channel_test_folder, test_snrs

            for snr_idx, snr_db in snr_pairs:
                row_data = raw_data[channel_idx][snr_idx][pulse_idx]
                # Even-indexed samples form the real part, odd-indexed the imaginary part.
                cmplx_data = row_data[0::2] + 1.0j * row_data[1::2]
                filename = f"{name_without_ext}_{channel_idx}_{snr_db}dB_{pulse_idx}.png"
                SpectrogramGenerator(cmplx_data, os.path.join(current_dir, filename))

# Define root folder to store our Spectrogram Datasets
root_spectrogram_folder = os.path.join(os.getcwd(), "SpectrogramData")
os.makedirs(root_spectrogram_folder, exist_ok=True)

root_npy_folder = os.path.join(os.getcwd(), "processed")

# Count waveform occurrences and group the .npy files by waveform type.
waveform_counts = defaultdict(int)
type_to_files = defaultdict(list)

# os.listdir order is arbitrary; sorting makes the summary and job order reproducible.
for file in sorted(os.listdir(root_npy_folder)):
    if not file.endswith(".npy"):
        continue

    # Waveform type = second '_'-separated field, lower-cased, minus any '#<index>'
    # suffix (e.g. 'barker5x13', 'fsk', 'nlfm').
    # NOTE(review): process_file names its class folders from the same field WITHOUT
    # lower-casing, so types differing only in case are counted together here but
    # written to separate folders.
    waveform_type = file.split('_')[1].lower().split('#')[0]

    waveform_counts[waveform_type] += 1
    type_to_files[waveform_type].append(file)

if not waveform_counts:
    # Fail with a clear message instead of max() raising on an empty sequence below.
    raise ValueError(f"No .npy files found in {root_npy_folder}")

# Print number of waveform types
for waveform, count in waveform_counts.items():
    print(f"{waveform} = {count}")

# Balance the classes. Each type gets about chosen_count pulses in total, spread
# evenly over its files (rounded up), so rarer types use more pulses per file.
# process_file still caps this at the pulses each file actually contains.
max_count = max(waveform_counts.values())
chosen_count = max_count * 5  # You can tweak this multiplier

file_repeat_pairs = []
for waveform_type, files in type_to_files.items():
    repeats_per_file = int(np.ceil(chosen_count / len(files)))
    file_repeat_pairs.extend(
        (os.path.join(root_npy_folder, file), repeats_per_file) for file in files
    )

# Print number of repeats per waveform to achieve balanced set
print("\nSample of file_repeat_pairs:")
for pair in file_repeat_pairs:
    print(f"{os.path.basename(pair[0])}, {pair[1]}")

num_cpus = multiprocessing.cpu_count()
max_workers = max(1, int(num_cpus * 0.8))  # leave roughly 20% of the cores free
# NOTE(review): under the "spawn" start method (Windows/macOS default), worker processes
# may be unable to import process_file and the globals it reads when they are defined
# inside a notebook -- confirm on those platforms.
with ProcessPoolExecutor(max_workers=max_workers) as executor:
    # list() drains the result iterator so any worker exception is re-raised here.
    list(tqdm(executor.map(process_file, file_repeat_pairs), total=len(file_repeat_pairs)))

#### PyTorch Dataloader code

In [5]:
import os
from torch.utils.data import Dataset
from torchvision import io

class SpectrogramDataset(Dataset):
    """Dataset of spectrogram PNGs laid out as ``root_dir/<class_name>/*.png``.

    Args:
        root_dir: Folder containing one sub-folder per class.
        class_to_idx: Optional ``{class_name: label}`` mapping. When omitted, the
            sorted sub-folder names of ``root_dir`` are numbered 0..N-1. Classes in
            the mapping that have no folder under ``root_dir`` are skipped.
    """

    def __init__(self, root_dir, class_to_idx=None):
        self.root_dir = root_dir

        # If no mapping is provided, create it from the sorted class folder names.
        if class_to_idx is None:
            classes = sorted(entry.name for entry in os.scandir(root_dir) if entry.is_dir())
            self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
        else:
            self.class_to_idx = class_to_idx

        # (image_path, label) pairs. File names are sorted because os.listdir order
        # is arbitrary. A fixed sample order is what lets the seeded DataLoader
        # shuffle reproduce across machines and filesystems.
        self.data = []
        for cls_name, label in self.class_to_idx.items():
            cls_folder = os.path.join(root_dir, cls_name)
            if not os.path.isdir(cls_folder):
                continue
            self.data.extend(
                (os.path.join(cls_folder, fname), label)
                for fname in sorted(os.listdir(cls_folder))
                if fname.lower().endswith('.png')
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)``; the decoded image is cast to float and divided by 255."""
        img_path, label = self.data[idx]
        image_tensor = io.read_image(img_path).float() / 255.0
        return image_tensor, label

#### Definition of the single-channel VGG-based network

In [6]:
class VGG13_net(nn.Module):
    """VGG-13-style classifier for single-channel 224x224 spectrograms.

    Architecture:
    - Features: five conv stages. Each stage is two 3x3 conv -> BatchNorm -> ReLU
      layers followed by 2x2 max-pooling. Together they reduce a 1x224x224 input
      to a 192x7x7 feature map.
    - Head: flatten, then a three-layer fully-connected classifier with
      dropout 0.5 after each hidden layer.

    Args:
        num_classes: Number of output logits.
    """

    # (in_channels, out_channels) per conv stage; each stage halves the spatial size:
    # 224 -> 112 -> 56 -> 28 -> 14 -> 7.
    _STAGES = ((1, 24), (24, 48), (48, 96), (96, 192), (192, 192))
    _FEATURE_DIM = 192 * 7 * 7  # flattened final feature map = 9408

    def __init__(self, num_classes):
        super().__init__()

        layers = []
        for in_ch, out_ch in self._STAGES:
            layers += self._conv_stage(in_ch, out_ch)
        # Same module order as a flat listing, so state_dict keys ("features.<i>.*")
        # are unchanged and existing checkpoints keep loading.
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self._FEATURE_DIM, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes),
        )

        self._init_weights()

    @staticmethod
    def _conv_stage(in_ch, out_ch):
        """Return the 7 modules of one stage: (conv-BN-ReLU) x 2, then a 2x2 max-pool."""
        return [
            nn.Conv2d(in_ch, out_ch, (3, 3), 1, 1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, (3, 3), 1, 1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        ]

    def _init_weights(self):
        """Kaiming-normal conv weights, Xavier-uniform linear weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, a=0.1)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        return self.classifier(self.features(x))


## Training and Validation Dataset Configuration

In [8]:
adc_num = 0
vgg_batchSize = 32

train_dataset = SpectrogramDataset(root_dir=f'SpectrogramData/ADC{adc_num}/train')
train_loader = DataLoader(train_dataset, batch_size=vgg_batchSize, shuffle=True)

# Reuse the training label mapping so a given class index means the same waveform in
# both splits, even if the test folder lacks a class (test-only classes are skipped).
test_dataset = SpectrogramDataset(
    root_dir=f'SpectrogramData/ADC{adc_num}/test',
    class_to_idx=train_dataset.class_to_idx,
)
test_loader = DataLoader(test_dataset, batch_size=vgg_batchSize, shuffle=False)

## Model Training

In [15]:
# Setup some tunable parameters for training
vgg_epochs = 100
vgg_momentum = 0.9      # NOTE(review): unused -- the Adam optimiser below takes no momentum argument
vgg_learningRate = 1e-4
vgg_weightDecay = 5e-4  # NOTE(review): unused -- not passed to the optimiser, so no weight decay is applied

num_unique_waveforms = len(train_dataset.class_to_idx)

model = VGG13_net(num_classes=num_unique_waveforms)
model.to(device)

# Setup loss function and optimiser
criterion = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model.parameters(), lr=vgg_learningRate)

best_accuracy = 0.0
patience = 10           # epochs without a strict accuracy increase before stopping
epochs_no_improve = 0
best_model_path = f"weights/adc{adc_num}_best_model.pth"
last_model_path = f"weights/adc{adc_num}_last_model.pth"

# torch.save does not create parent folders, so make sure "weights/" exists.
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

# Train model
for epoch in range(vgg_epochs):
    model.train()
    running_loss = 0.0

    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{vgg_epochs} - Training", leave=False)

    for images, labels in train_bar:
        images, labels = images.to(device), labels.to(device)

        optimiser.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimiser.step()

        # Weight the batch-mean loss by batch size so epoch_loss is a per-sample mean.
        running_loss += loss.item() * images.size(0)
        train_bar.set_postfix(loss=loss.item())

    epoch_loss = running_loss / len(train_dataset)
    print(f"\n[Epoch {epoch+1}] Training Loss: {epoch_loss:.4f}")

    # Evaluation
    model.eval()
    correct = 0
    total = 0

    test_bar = tqdm(test_loader, desc=f"Epoch {epoch+1}/{vgg_epochs} - Testing", leave=False)
    with torch.no_grad():
        for images, labels in test_bar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            test_bar.set_postfix(accuracy=100 * correct / total)

    # Guard against an empty test set instead of dividing by zero.
    accuracy = 100 * correct / total if total else 0.0
    print(f"[Epoch {epoch+1}] Test Accuracy: {accuracy:.2f}%\n")

    # Early-stopping bookkeeping (previously never updated, so it could never fire).
    # Only a strict increase resets the counter. Ties still refresh the saved "best"
    # weights below, but they do not extend training.
    if accuracy > best_accuracy:
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    # Save best model based on accuracy
    if accuracy >= best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), best_model_path)
        print(f"[Epoch {epoch+1}] ✅ Best model saved with accuracy: {accuracy:.2f}%")

    # Always save last model so that training can be resumed
    torch.save(model.state_dict(), last_model_path)

    # Check early stopping condition
    if epochs_no_improve >= patience:
        print(f"⏹️ Early stopping triggered after {patience} epochs with no improvement.")
        break  # exits the epoch loop


                                                                                

KeyboardInterrupt: 

## Validation (Confusion Matrix Per SNR)

In [14]:
from collections import defaultdict
import re

# Organise test data by modulation type and SNR so we can form Conf matrix
def organise_data_by_modulation_and_snr(root_dir):
    """Group PNG paths by modulation folder and SNR.

    Args:
        root_dir: Folder with one sub-folder per modulation, holding PNGs named like
            ``<stem>_<channel>_<snr>dB_<pulse>.png``.

    Returns:
        Nested defaultdicts: ``result[modulation][snr] -> list of image paths``,
        with paths in sorted file-name order. Files without an ``<int>dB`` token
        are skipped, as are non-directory entries of ``root_dir``.
    """
    # The SNR token is the LAST "<int>dB" in the name. Taking the last match stops
    # digits in the file stem (e.g. "...7dB#2...") from being read as the SNR.
    snr_pattern = re.compile(r'(-?\d+)dB', re.IGNORECASE)
    data_structure = defaultdict(lambda: defaultdict(list))

    for modulation_folder in sorted(os.listdir(root_dir)):
        modulation_path = os.path.join(root_dir, modulation_folder)
        if not os.path.isdir(modulation_path):
            continue

        for filename in sorted(os.listdir(modulation_path)):
            if not filename.lower().endswith('.png'):
                continue

            snr_tokens = snr_pattern.findall(filename)
            if snr_tokens:
                snr = int(snr_tokens[-1])
                data_structure[modulation_folder][snr].append(os.path.join(modulation_path, filename))

    return data_structure


# Dataset over an explicit subset of image files (e.g. every test image at one SNR)
class FilteredDataset(Dataset):
    """Dataset built from an explicit list of PNG paths.

    Each image's label is looked up from the name of its parent folder.

    Args:
        file_paths: Image paths, each located inside a ``<modulation>/`` folder.
        class_to_idx: ``{modulation_name: label}`` mapping.
    """

    def __init__(self, file_paths, class_to_idx):
        self.file_paths = file_paths
        self.class_to_idx = class_to_idx

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)``; the image is cast to float and divided by 255."""
        path = self.file_paths[idx]
        parent_folder = os.path.basename(os.path.dirname(path))
        target = self.class_to_idx[parent_folder]
        pixels = io.read_image(path).float()
        return pixels / 255.0, target

# setup dataset
root_dir = f'SpectrogramData/ADC{adc_num}/test'
data_structure = organise_data_by_modulation_and_snr(root_dir)

# Create classes/index mappings (sorted class folder names -> 0..N-1)
classes = sorted(entry.name for entry in os.scandir(root_dir) if entry.is_dir())
class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
idx_to_class = {idx: cls_name for cls_name, idx in class_to_idx.items()}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Size the classifier head from the same class list used for the label mapping above.
num_unique_waveforms = len(classes)

# Load our model and pre-trained weights
model = VGG13_net(num_unique_waveforms).to(device)
model.load_state_dict(torch.load(f"weights/adc{adc_num}_best_model.pth", map_location=device))
model.eval()

# Group files by SNR across all modulations
snr_to_files = defaultdict(list)
for snr_dict in data_structure.values():
    for snr, paths in snr_dict.items():
        snr_to_files[snr].extend(paths)

# Process each SNR level
for snr in sorted(snr_to_files):
    # Create dataset subset for this SNR
    filtered_dataset = FilteredDataset(snr_to_files[snr], class_to_idx)
    dataloader = DataLoader(filtered_dataset, batch_size=vgg_batchSize, shuffle=False)
    print(f"Processing SNR {snr}dB: {len(filtered_dataset)} samples")

    # Collect all predictions and labels
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in dataloader:
            predictions = model(images.to(device)).argmax(dim=1)
            all_labels.extend(labels.tolist())
            all_predictions.extend(predictions.cpu().tolist())

    # Convert indexes back to class names
    true_labels = [idx_to_class[idx] for idx in all_labels]
    pred_labels = [idx_to_class[idx] for idx in all_predictions]

    # Row-normalised ('true') confusion matrix over every class, including classes
    # that happen to have no samples at this SNR.
    cm = confusion_matrix(true_labels, pred_labels, labels=classes, normalize='true')

    disp = ConfusionMatrixDisplay(cm, display_labels=classes)
    fig, ax = plt.subplots(figsize=(10, 10))
    disp.plot(ax=ax, cmap="viridis", values_format=".2f", colorbar=False)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    plt.title(f'Confusion Matrix at SNR = {snr}dB for ADC{adc_num}')
    plt.tight_layout()
    plt.savefig(f'adc{adc_num}_confusion_matrix_snr_{snr}dB.png', dpi=150)
    plt.close()


Processing SNR -30dB: 2304 samples

Processing SNR -27dB: 2304 samples

Processing SNR -24dB: 2304 samples

Processing SNR -21dB: 2304 samples

Processing SNR -18dB: 2304 samples

Processing SNR -15dB: 2304 samples

Processing SNR -12dB: 2304 samples

Processing SNR -9dB: 2304 samples

Processing SNR -6dB: 2304 samples

Processing SNR -3dB: 2304 samples

Processing SNR 0dB: 2304 samples

Processing SNR 3dB: 2304 samples

Processing SNR 6dB: 2304 samples

Processing SNR 9dB: 2304 samples

Processing SNR 12dB: 2304 samples

Processing SNR 15dB: 2304 samples

Processing SNR 18dB: 2304 samples

Processing SNR 21dB: 2304 samples

Processing SNR 24dB: 2304 samples

Processing SNR 27dB: 2304 samples

Processing SNR 30dB: 2304 samples
