In [1]:

from torchsig.datasets.wideband_sig53 import WidebandSig53
from torchmetrics.detection import MeanAveragePrecision
from torch.utils.data import DataLoader
from torchsig.transforms.target_transforms import DescToMaskClass
from torchsig.transforms.transforms import Spectrogram, Normalize, Compose, Identity
from tqdm import tqdm
import matplotlib as mpl
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import pandas as pd
import numpy as np
import torch
import time
import os
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
import pywt
import torch
from matplotlib import patches
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from scipy import ndimage
from scipy import signal as sp
from torch.utils.data import dataloader
from sklearn.preprocessing import LabelEncoder

In [2]:
from torchsig.utils.visualize import MaskClassVisualizer, mask_class_to_outline, complex_spectrogram_to_magnitude
from torchsig.transforms.target_transforms import DescToMaskClass
from torchsig.datasets.wideband import WidebandModulationsDataset
from torchsig.transforms.transforms import Spectrogram, Normalize, Compose
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [3]:
# The 53 modulation classes. A name's position in this list is the class
# index used downstream (e.g. `class_list[idx]` lookups in the visualizer).
modulation_list = [
    # amplitude / phase keyed and QAM families
    "ook", "bpsk", "4pam", "4ask", "qpsk",
    "8pam", "8ask", "8psk",
    "16qam", "16pam", "16ask", "16psk",
    "32qam", "32qam_cross", "32pam", "32ask", "32psk",
    "64qam", "64pam", "64ask", "64psk",
    "128qam_cross", "256qam", "512qam_cross", "1024qam",
    # frequency-shift keyed families: 2fsk, 2gfsk, 2msk, 2gmsk, 4fsk, ..., 16gmsk
    *(f"{order}{kind}" for order in (2, 4, 8, 16) for kind in ("fsk", "gfsk", "msk", "gmsk")),
    # OFDM variants: ofdm-64 ... ofdm-2048
    *(f"ofdm-{n}" for n in (64, 72, 128, 180, 256, 300, 512, 600, 900, 1024, 1200, 2048)),
]

# Spectrogram geometry: fft_size-point segments with no overlap over
# fft_size * fft_size IQ samples give fft_size frames of fft_size bins.
fft_size = 512
num_iq_samples = fft_size ** 2
num_classes = len(modulation_list)
num_samples = 150  # number of samples requested from the dataset

# Input transform: complex spectrogram, then normalization with norm=np.inf.
data_transform = Compose(
    [
        Spectrogram(nperseg=fft_size, noverlap=0, nfft=fft_size, mode='complex'),
        Normalize(norm=np.inf, flatten=True),
    ]
)

# Target transform: per-class masks sized to match the spectrogram.
target_transform = Compose(
    [DescToMaskClass(num_classes=num_classes, width=fft_size, height=fft_size)]
)

In [4]:
# Dataset built from the configuration cell above: modulation classes, capture
# length and sample count, plus the spectrogram / mask transforms.
# NOTE(review): `level` is passed straight through; its meaning is defined by
# torchsig's WidebandModulationsDataset - confirm there before changing it.
wideband_modulations_dataset = WidebandModulationsDataset(
    level=1,
    modulation_list=modulation_list,
    num_samples=num_samples,
    num_iq_samples=num_iq_samples,
    target_transform=target_transform,
    transform=data_transform,
)

In [5]:

from torch.utils.data import random_split

# Fraction of the dataset used for training; the remainder is held out.
train_ratio = 0.8
test_ratio = 1 - train_ratio

# Split sizes from the actual dataset length. The test split takes the
# remainder so the two always add up exactly (no sample lost to rounding).
# A dedicated name avoids silently rebinding the `num_samples` config value.
num_total_samples = len(wideband_modulations_dataset)
train_size = int(train_ratio * num_total_samples)
test_size = num_total_samples - train_size

# Seeded generator so train/test membership is reproducible across
# Restart & Run All; without it random_split draws from torch's global RNG.
split_seed = 0
train_dataset, test_dataset = random_split(
    wideband_modulations_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

In [6]:


def complex_spectrogram_to_magnitude(tensor: np.ndarray) -> np.ndarray:
    """Visualizer data transform: two-channel spectrogram -> log-magnitude image
    (mode = 'complex').

    Args:
        tensor: Batch shaped ``(batch, 2, height, width)``, as a numpy array or
            CPU torch tensor. Channels 0 and 1 are squared and summed, so they
            are presumably the real and imaginary parts.

    Returns:
        float64 array shaped ``(batch, height, width)`` holding
        ``20 * log10(power)``. Bins with zero power are clamped to the weakest
        non-zero bin of the batch so every value stays finite: ``log10(0)``
        is ``-inf``, which breaks the ``vmin``/``vmax`` colour scaling used
        when plotting. If the whole batch has zero power, the result is all 0.
    """
    data = np.asarray(tensor, dtype=np.float64)
    power = data[:, 0] ** 2 + data[:, 1] ** 2
    nonzero = power[power > 0]
    floor = nonzero.min() if nonzero.size else 1.0
    # NOTE: power is already |z|^2, so 10*log10 would be the conventional dB
    # scale. The factor 20 is kept so existing outputs are unchanged. It only
    # rescales the image, and plotting normalizes with vmin/vmax anyway.
    return 20 * np.log10(np.maximum(power, floor))

def mask_class_to_outline(tensor: np.ndarray) -> Tuple[List[List[int]], List[Any]]:
    """Target transform: turn per-class burst masks into outline masks for the
    MaskClassVisualizer.

    For each sample:
    - record which class channels contain any set pixel;
    - reduce every channel to its boundary (mask minus its binary erosion);
    - merge the channels;
    - thicken the merged outline with two rounds of 8-connected dilation.

    Overlapping outlines are still shown as overlapping.

    Args:
        tensor: Batch of masks indexed ``[sample, class_channel, row, col]``,
            as a torch tensor or numpy array. The input is copied, never
            modified.

    Returns:
        ``(class_idx, labels)``:
        - ``class_idx[i]`` lists the channel indices with at least one
          non-zero pixel in sample ``i``.
        - ``labels[i]`` is that sample's outline as a masked array with the
          zero pixels masked.
    """
    struct = ndimage.generate_binary_structure(2, 2)
    class_idx = []
    labels = []
    for idx in range(tensor.shape[0]):
        # np.array copies. `.numpy()` would alias the caller's tensor, and the
        # in-place erosion below would then overwrite the caller's masks.
        label = np.array(tensor[idx])
        class_idx.append([c for c in range(label.shape[0]) if np.count_nonzero(label[c]) > 0])
        for burst_idx in range(label.shape[0]):
            label[burst_idx] = label[burst_idx] - ndimage.binary_erosion(label[burst_idx])
        outline = np.sum(label, axis=0)
        outline[outline > 0] = 1
        outline = ndimage.binary_dilation(outline, structure=struct, iterations=2).astype(outline.dtype)
        labels.append(np.ma.masked_where(outline == 0, outline))
    return class_idx, labels

In [7]:


class Visualizer:
    """A non-entirely abstract class which represents a visualization of a dataset.

    Iterating a ``Visualizer`` pulls ``(iq_data, targets)`` batches from
    ``data_loader``, applies the optional transforms, and hands the result to
    ``_visualize``, which subclasses implement.

    Args:
        data_loader:
            An iterable (e.g. a DataLoader) yielding ``(iq_data, targets)``
            batches. Every ``iter()`` call starts a fresh pass over it, so
            the visualizer can be iterated more than once.

        visualize_transform:
            Defines how to transform the data prior to plotting

        visualize_target_transform:
            Defines how to transform the target prior to plotting

    """

    def __init__(
        self,
        data_loader,
        visualize_transform: Optional[Callable] = None,
        visualize_target_transform: Optional[Callable] = None,
    ) -> None:
        # Keep the iterable itself rather than iter(...) of it. An iterator
        # created here would be exhausted after one pass, and every later
        # iteration would silently yield nothing.
        self.data_loader = data_loader
        self.visualize_transform = visualize_transform
        self.visualize_target_transform = visualize_target_transform

    def __iter__(self) -> Iterable:
        """Start a new pass over ``data_loader`` and return ``self``."""
        self.data_iter = iter(self.data_loader)
        return self  # type: ignore

    def __next__(self) -> "Figure":
        """Fetch the next batch, apply the transforms and visualize it.

        Raises:
            StopIteration: when the current pass over ``data_loader`` ends.
        """
        iq_data, targets = next(self.data_iter)
        if self.visualize_transform:
            iq_data = self.visualize_transform(iq_data)

        if self.visualize_target_transform:
            targets = self.visualize_target_transform(targets)

        return self._visualize(iq_data, targets)

    def _visualize(self, iq_data: np.ndarray, targets: np.ndarray) -> "Figure":
        """Build the figure for one transformed batch; subclasses must override."""
        raise NotImplementedError

In [13]:
import os
import cv2
import numpy as np
from typing import List, Optional, Callable
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
from copy import deepcopy
from matplotlib.figure import Figure


class MaskClassVisualizer(Visualizer):
    """Export spectrogram images and per-burst box labels from a data loader.

    Each sample of a batch is drawn as a ``jet`` image, with the outline masks
    optionally overlaid, and saved as ``<save_dir>/images/spectrogram_<n>.png``.
    It is paired with ``<save_dir>/labels/labels_<n>.txt``, which holds one
    ``<class_id> <x_center> <y_center> <width> <height>`` line per burst.
    Coordinates are normalised to ``[0, 1]`` by the mask width and height.
    ``<n>`` is a running index, so later batches never overwrite earlier files.

    NOTE(review): YOLO-style loaders usually expect a label file to share the
    image's stem (``spectrogram_<n>.txt``). Confirm that the consumer of these
    files accepts the ``labels_<n>.txt`` naming.

    Args:
        data_loader: Iterable yielding ``(data, targets)`` batches. Each
            ``targets[i]`` is assumed to be indexed by class channel on axis 0,
            the layout ``mask_class_to_outline`` iterates over.
        visualize_transform: Applied to a copy of ``data`` before plotting.
        visualize_target_transform: Applied to a copy of ``targets``. It must
            return ``(classes, outline_masks)``, like ``mask_class_to_outline``.
        class_list: Class names indexed by mask channel. The written class id
            is the name's index under ``LabelEncoder`` (sorted names). When
            ``None``, the raw channel index is written instead.
        save_dir: Root output directory. Required.
        overlay_masks: Draw the outline masks onto the saved images. Defaults
            to ``True``, the original behaviour. Set it to ``False`` when the
            images are model inputs; otherwise the ground-truth outlines are
            baked into the training images.
    """

    def __init__(
        self,
        data_loader,
        visualize_transform: Optional[Callable] = None,
        visualize_target_transform: Optional[Callable] = None,
        class_list: Optional[List[str]] = None,
        save_dir: Optional[str] = None,
        overlay_masks: bool = True,
    ) -> None:
        if save_dir is None:
            # os.makedirs(None) would otherwise fail with an opaque TypeError.
            raise ValueError("MaskClassVisualizer needs a save_dir to write images and labels into")
        super(MaskClassVisualizer, self).__init__(data_loader, visualize_transform, visualize_target_transform)
        self.class_list = class_list
        self.save_dir = save_dir
        self.overlay_masks = overlay_masks
        self.images_dir = os.path.join(self.save_dir, "images")
        self.labels_dir = os.path.join(self.save_dir, "labels")
        # makedirs also creates save_dir as the common parent.
        os.makedirs(self.images_dir, exist_ok=True)
        os.makedirs(self.labels_dir, exist_ok=True)
        # Class id for each mask channel. LabelEncoder sorts the names, so the
        # ids are the names' ranks in sorted order.
        self._class_ids = None if class_list is None else LabelEncoder().fit_transform(class_list)
        # Running file index shared across the batches of one pass.
        self._num_saved = 0

    def __iter__(self):
        """Start a new export pass; file numbering restarts at 0."""
        self._num_saved = 0
        return super(MaskClassVisualizer, self).__iter__()

    def __next__(self) -> Figure:
        iq_data, targets = next(self.data_iter)
        # Keep the untouched per-class masks. Box labels are derived from
        # them, not from the merged/dilated outlines used for display.
        raw_targets = np.asarray(targets)
        if self.visualize_transform:
            iq_data = self.visualize_transform(deepcopy(iq_data))

        classes = None
        outlines = None
        if self.visualize_target_transform:
            classes, outlines = self.visualize_target_transform(deepcopy(targets))

        return self._visualize(iq_data, outlines, classes, raw_targets=raw_targets)

    def _visualize(
        self,
        data: np.ndarray,
        targets: Optional[List[np.ndarray]],
        classes: Optional[List[List[int]]],
        raw_targets: Optional[np.ndarray] = None,
    ) -> Figure:
        """Save one image and one label file per sample and return the last figure.

        Args:
            data: Transformed batch indexed ``[sample, row, col]``.
            targets: Per-sample outline masks to overlay, or ``None``.
            classes: Per-sample class indices from the target transform. Kept
                for signature compatibility. Labels are derived from
                ``raw_targets``, so each box carries the class of the channel
                it was found in.
            raw_targets: Untransformed per-class masks indexed
                ``[sample, class_channel, row, col]``. When ``None``, the label
                files are written empty.

        Returns:
            The last figure (already closed), or ``None`` for an empty batch.
        """
        vmin, vmax = np.min(data), np.max(data)
        fig = None
        for sample_idx in range(data.shape[0]):
            fig, ax = plt.subplots(figsize=(8, 6))
            ax.imshow(
                data[sample_idx],
                vmin=vmin,
                vmax=vmax,
                cmap="jet",
                extent=[0, data.shape[2], 0, data.shape[1]],
            )
            ax.axis('off')

            if self.overlay_masks and targets is not None:
                mask = targets[sample_idx]
                ax.imshow(
                    np.ma.masked_where(mask < 0.5, mask),
                    vmin=np.min(mask),
                    vmax=np.max(mask),
                    cmap="gray",
                    alpha=0.5,
                    interpolation="none",
                    extent=[0, mask.shape[1], 0, mask.shape[0]],
                )

            file_idx = self._num_saved + sample_idx
            fig.savefig(
                os.path.join(self.images_dir, f"spectrogram_{file_idx}.png"),
                dpi=300,
                bbox_inches='tight',
                pad_inches=0,
            )
            plt.close(fig)

            label_lines = [] if raw_targets is None else self._label_lines(raw_targets[sample_idx])
            with open(os.path.join(self.labels_dir, f"labels_{file_idx}.txt"), "w") as f:
                f.writelines(label_lines)

        self._num_saved += data.shape[0]
        return fig

    def _label_lines(self, class_masks: np.ndarray) -> List[str]:
        """Return one ``"<class_id> xc yc w h\\n"`` line per burst in ``class_masks``.

        ``class_masks`` is indexed by class channel on axis 0. Every external
        contour within a channel becomes one box labelled with that channel's
        class, so the class/box pairing is exact. Overlapping bursts of the
        same class merge into one box.
        """
        lines = []
        for channel_idx, channel_mask in enumerate(class_masks):
            if not np.any(channel_mask):
                continue
            class_id = channel_idx if self._class_ids is None else self._class_ids[channel_idx]
            for box in self._calculate_bounding_box(channel_mask):
                coords = " ".join(str(coord) for coord in box)
                lines.append(f"{class_id} {coords}\n")
        return lines

    @staticmethod
    def _calculate_bounding_box(mask: np.ndarray) -> List[List[float]]:
        """Return a normalised ``[xc, yc, w, h]`` box for each external contour of ``mask``.

        Any non-zero pixel counts as foreground. Boxes cover whole pixels: a
        blob spanning columns a..b is ``b - a + 1`` wide. Coordinates are
        divided by the mask width (``shape[1]``) and height (``shape[0]``)
        rather than a fixed 512, so any spectrogram size works.
        """
        binary = (np.asarray(mask) > 0).astype(np.uint8)
        height, width = binary.shape
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        bboxes = []
        for contour in contours:
            points = contour.reshape(-1, 2)  # columns are (x, y)
            x_min, y_min = points.min(axis=0)
            x_max, y_max = points.max(axis=0)
            w = x_max - x_min + 1
            h = y_max - y_min + 1
            bboxes.append([
                (x_min + w / 2.0) / width,
                (y_min + h / 2.0) / height,
                w / width,
                h / height,
            ])
        return bboxes


In [14]:
data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=120,
    shuffle=True,
)

save_dir = "train"

visualizer = MaskClassVisualizer(
    data_loader=data_loader,
    visualize_transform=complex_spectrogram_to_magnitude,
    visualize_target_transform=mask_class_to_outline,
    class_list=modulation_list,
    save_dir=save_dir,
)

# The visualizer saves and closes every figure itself, so this loop only
# drives the export. Consume every batch: stopping after the first one would
# silently drop samples whenever the split is larger than batch_size.
for _ in visualizer:
    pass

In [15]:
data_loader = DataLoader(
    dataset=test_dataset,
    batch_size=25,
    shuffle=True,
)

save_dir = "test"

visualizer = MaskClassVisualizer(
    data_loader=data_loader,
    visualize_transform=complex_spectrogram_to_magnitude,
    visualize_target_transform=mask_class_to_outline,
    class_list=modulation_list,
    save_dir=save_dir,
)

# The visualizer saves and closes every figure itself, so this loop only
# drives the export. The test split (30 samples with the config above) is
# larger than batch_size=25, so every batch must be consumed or samples are lost.
for _ in visualizer:
    pass