# Training a Model on IQ Samples for Classification

This notebook demonstrates how to train a PyTorch model on IQ Samples for modulation recognition.

---

In [None]:
# Variables
from torchsig.signals.signal_lists import TorchSigSignalLists
from torchsig.transforms.transforms import ComplexTo2D
import os

from torch import Tensor

# Output location: a common root with one sub-directory per dataset split.
root = "./datasets/classifier_example"
for _split_dir in ("", "/train", "/val", "/test"):
    os.makedirs(root + _split_dir, exist_ok=True)

# Signal geometry: the number of IQ samples per example is the FFT size squared.
fft_size = 256
num_iq_samples_dataset = fft_size * fft_size

# Class bookkeeping taken from TorchSig's built-in signal lists.
class_list = TorchSigSignalLists.all_signals
family_list = TorchSigSignalLists.family_list
num_classes = len(class_list)

# Dataset sizes scale with the number of classes.
num_samples_train = 5 * num_classes  # roughly 5 samples per class
num_samples_val = 2 * num_classes

impairment_level = 0
seed = 123456789

# IQ-based mod-rec only operates on 1 signal
num_signals_max = 1
num_signals_min = 1

# ComplexTo2D turns an IQ array of complex values into a 2D array: one channel
# holds the real component and the other holds the imaginary component.
transforms = [ComplexTo2D()]

# Settings handed to the TorchSig datasets via their `metadata` argument.
dataset_metadata = dict(
    num_iq_samples_dataset=num_iq_samples_dataset,
    fft_size=fft_size,
    fft_stride=fft_size,
    num_signals_max=num_signals_max,
    num_signals_min=num_signals_min,
    noise_power_db=1,
    signal_center_freq_min=1000,
    signal_center_freq_max=2000,
    sample_rate=10000,
    frequency_min=1000,
    frequency_max=2000,
    cochannel_overlap_probability=0.2,
    signal_duration_in_samples_min=2000,
    signal_duration_in_samples_max=8000,
    bandwidth_min=1000,
    bandwidth_max=2000,
)

## Create the Dataset

In [None]:
from torchsig.datasets.datasets import TorchSigIterableDataset, StaticTorchSigDataset
from torchsig.utils.data_loading import WorkerSeedingDataLoader
from torchsig.utils.writer import DatasetCreator

# Two on-the-fly generators that share the same metadata and transforms.
train_dataset = TorchSigIterableDataset(
    metadata=dataset_metadata,
    transforms=transforms,
    target_labels=None,
    signal_generators="all",
)
# NOTE(review): unlike train_dataset, this one does not pass signal_generators;
# confirm the library default yields the same class set as "all".
val_dataset = TorchSigIterableDataset(
    metadata=dataset_metadata, transforms=transforms, target_labels=None
)

# From here on, class names come from the dataset itself.
class_list = train_dataset.class_names

# The identity collate_fn hands each batch to the writer as a plain list of samples.
train_dataloader = WorkerSeedingDataLoader(
    train_dataset, batch_size=4, collate_fn=lambda x: x
)
val_dataloader = WorkerSeedingDataLoader(val_dataset, collate_fn=lambda x: x)

# Apply the notebook-level seed so the datasets written to disk are reproducible
# across runs. The validation loader gets a different seed so that it does not
# regenerate the same examples as the training loader.
train_dataloader.seed(seed)
val_dataloader.seed(seed + 1)

# Write each split to disk, replacing anything left over from a previous run.
for split_name, split_loader, split_length in (
    ("train", train_dataloader, num_samples_train),
    ("val", val_dataloader, num_samples_val),
):
    dc = DatasetCreator(
        dataloader=split_loader,
        root=f"{root}/{split_name}",
        overwrite=True,
        dataset_length=split_length,
    )
    dc.create()

In [None]:
# Re-open the splits that were written to disk as static datasets, requesting the
# "class_index" target label for every item.
static_splits = {
    split: StaticTorchSigDataset(root=f"{root}/{split}", target_labels=["class_index"])
    for split in ("train", "val")
}
train_dataset = static_splits["train"]
val_dataset = static_splits["val"]

# Training batches hold 4 items; the validation loader keeps its default batch size.
train_dataloader = WorkerSeedingDataLoader(train_dataset, batch_size=4)
val_dataloader = WorkerSeedingDataLoader(val_dataset)

# Show a single stored item as a sanity check.
print(train_dataset[3])

In [None]:
# Pull the first batch from the training dataloader and display it as the cell output.
next(iter(train_dataloader))

## Create the Model

We use our own XCiT model code and utilities, but this can be replaced with your own model architecture in PyTorch, Ultralytics, timm, etc.

In [None]:
!pip install timm pytorch_lightning

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm


class ConvDownSampler(nn.Module):
    """Downsample a 1D sequence with one strided convolution, then BatchNorm and GELU.

    Args:
        in_chans: Number of input channels of the convolution.
        embed_dim: Number of output channels of the convolution (and BatchNorm features).
        ds_rate: Convolution stride. The kernel size is ``2 * ds_rate`` and the
            padding is ``ds_rate // 2``.
    """

    def __init__(self, in_chans: int, embed_dim: int, ds_rate: int = 16):
        super().__init__()
        kernel_width = 2 * ds_rate
        edge_padding = ds_rate // 2
        # A single strided convolution does both the embedding and the downsampling.
        self.conv = nn.Conv1d(
            in_chans,
            embed_dim,
            kernel_width,
            stride=ds_rate,
            padding=edge_padding,
        )
        self.bn = nn.BatchNorm1d(embed_dim)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, batch normalisation and GELU, in that order."""
        return self.act(self.bn(self.conv(x)))


class Chunker(nn.Module):
    """Embed a 1D sequence with a width-7 convolution, then average-pool it.

    Args:
        in_chans: Number of input channels of the embedding convolution.
        embed_dim: Number of output channels of the embedding convolution.
        ds_rate: Kernel size and stride of the average pooling.
    """

    def __init__(self, in_chans: int, embed_dim: int, ds_rate: int = 16):
        super().__init__()
        self.ds_rate = ds_rate
        # kernel_size=7 with padding=3 keeps the sequence length unchanged.
        self.embed = nn.Conv1d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=7,
            padding=3,
        )
        # Non-overlapping windows of ds_rate samples are averaged together.
        self.pool = nn.AvgPool1d(ds_rate, stride=ds_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed to ``[B, embed_dim, L]``, then downsample by averaging."""
        embedded = self.embed(x)
        return self.pool(embedded)


class XCiT1d(nn.Module):
    """A 1D implementation of the XCiT architecture.

    Builds a timm XCiT backbone, swaps its patch embedding for a 1D module
    (ConvDownSampler or Chunker), replaces its head with an identity, and maps
    the normalised classification token to the outputs with a kernel-size-1
    Conv1d (``self.grouper``).

    Args:
        input_channels (int): Number of 1D input channels.
        n_features (int): Number of output features/classes.
        xcit_version (str): Version of XCiT model to use (e.g., 'nano_12_p16_224').
            An 'xcit_' prefix is added when it is not already present.
        drop_path_rate (float): Drop path rate, forwarded to ``timm.create_model``.
        drop_rate (float): Dropout rate, forwarded to ``timm.create_model``.
        ds_method (str): Downsampling method ('downsample' or 'chunk').
        ds_rate (int): Downsampling rate (e.g., 2 for downsampling by a factor of 2).

    Raises:
        ValueError: If ``ds_method`` is neither 'downsample' nor 'chunk'.
    """

    def __init__(
        self,
        input_channels: int,
        n_features: int,
        xcit_version: str = "nano_12_p16_224",
        drop_path_rate: float = 0.0,
        drop_rate: float = 0.3,
        ds_method: str = "downsample",
        ds_rate: int = 2,
    ):
        super().__init__()

        # Ensure the model name carries the "xcit_" prefix
        model_name = (
            f"xcit_{xcit_version}"
            if not xcit_version.startswith("xcit_")
            else xcit_version
        )

        # Create the backbone model (randomly initialised: pretrained=False)
        self.backbone = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=n_features,
            in_chans=input_channels,
            drop_path_rate=drop_path_rate,
            drop_rate=drop_rate,
        )

        # Number of features from the backbone
        W = self.backbone.num_features

        # Include the grouper Conv1d layer: a kernel-size-1 convolution from the
        # W backbone features to n_features outputs
        self.grouper = nn.Conv1d(W, n_features, kernel_size=1)

        # Replace the patch embedding with a 1D version
        if ds_method == "downsample":
            self.backbone.patch_embed = ConvDownSampler(input_channels, W, ds_rate)
        elif ds_method == "chunk":
            self.backbone.patch_embed = Chunker(input_channels, W, ds_rate)
        else:
            raise ValueError(
                f"{ds_method} is not a supported downsampling method; currently 'downsample' and 'chunk' are supported"
            )

        # Replace the classifier head with an identity layer (since we use self.grouper)
        self.backbone.head = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the 1D XCiT pipeline and return one row of n_features outputs per batch item.

        Assumes ``x`` is shaped [B, input_channels, L] so the Conv1d-based patch
        embedding can consume it -- TODO confirm against the dataloader output.
        The backbone's ``head`` is not called here; the output comes from
        ``self.grouper`` applied to the classification token.
        """
        mdl = self.backbone
        B = x.shape[0]

        # Patch embedding
        x = self.backbone.patch_embed(x)  # Shape: [B, C, L]

        # Define H and W for 1D data
        Hp, Wp = x.shape[-1], 1  # Height is sequence length, Width is 1

        # Obtain positional encoding and rearrange it to [B, Hp, C] so it lines up
        # with the transposed embedding below
        pos_encoding = mdl.pos_embed(B, Hp, Wp).reshape(B, -1, Hp).permute(0, 2, 1)

        # Add positional encoding
        x = x.transpose(1, 2) + pos_encoding  # Shape: [B, Hp, C]

        # Apply transformer blocks (each block is also given the Hp x Wp grid size)
        for blk in mdl.blocks:
            x = blk(x, Hp, Wp)

        # Classification token, prepended to the token sequence
        cls_tokens = mdl.cls_token.expand(B, -1, -1)  # Shape: [B, 1, C]
        x = torch.cat((cls_tokens, x), dim=1)  # Shape: [B, Hp+1, C]

        # Apply class attention blocks
        for blk in mdl.cls_attn_blocks:
            x = blk(x)

        # Layer normalization
        x = mdl.norm(x)  # Shape: [B, Hp+1, C]

        # Apply the grouper Conv1d to the classification token
        # Extract the classification token (first token)
        cls_token = x[:, 0, :]  # Shape: [B, C]

        # Reshape for Conv1d: [B, C, 1]
        cls_token = cls_token.unsqueeze(-1)  # Shape: [B, C, 1]

        # Apply the grouper Conv1d
        x = self.grouper(cls_token).squeeze(-1)  # Shape: [B, n_features]

        # If x is 1D (batch size 1), ensure it has the correct shape
        # NOTE(review): squeeze(-1) above only drops the trailing length-1 axis, so
        # x appears to always be 2-D here and this branch looks unreachable -- confirm.
        if x.dim() == 1:
            x = x.unsqueeze(0)

        return x


class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    For a target class ``y`` with softmax probability ``p_y`` the per-element loss is
    ``alpha_y * (1 - p_y) ** gamma * (-log p_y)``.

    Args:
        gamma: Focusing exponent applied to ``(1 - p_y)``.
        alpha: Optional class weighting. ``None`` disables it, a scalar scales every
            element equally, and a tensor of shape ``[num_classes]`` weights each
            element by its target class.
        reduction: ``"mean"``, ``"sum"``, or anything else for the unreduced loss.
        ignore_index: Target value whose elements contribute zero loss.
    """

    def __init__(self, gamma=2.0, alpha=None, reduction="mean", ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha  # None, a scalar, or a tensor of shape [num_classes]
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (logits, classes on dim 1) against ``targets``.

        With ``reduction="mean"`` the average is taken over the non-ignored targets
        only, so ignored elements do not dilute the loss.
        """
        log_probs = F.log_softmax(inputs, dim=1)

        # Unweighted NLL: this is exactly -log(p_y), so exp(-nll) recovers the true
        # target probability. Deriving p_y from an alpha-weighted NLL would distort
        # the modulating factor whenever alpha != 1.
        nll = F.nll_loss(
            log_probs,
            targets,
            reduction="none",
            ignore_index=self.ignore_index,
        )
        probs = torch.exp(-nll)
        modulating_factor = (1 - probs) ** self.gamma

        if self.alpha is None:
            weighted_nll = nll
        else:
            alpha = torch.as_tensor(
                self.alpha, dtype=inputs.dtype, device=inputs.device
            )
            if alpha.dim() == 0:
                # Scalar alpha: uniform scaling (nll_loss only accepts per-class weights).
                weighted_nll = alpha * nll
            else:
                weighted_nll = F.nll_loss(
                    log_probs,
                    targets,
                    weight=alpha,
                    reduction="none",
                    ignore_index=self.ignore_index,
                )

        focal_loss = modulating_factor * weighted_nll

        if self.reduction == "mean":
            # Ignored targets already contribute 0 to the sum; exclude them from the
            # denominator too. clamp avoids 0/0 when every target is ignored.
            num_valid = (targets != self.ignore_index).sum().clamp(min=1)
            return focal_loss.sum() / num_valid
        elif self.reduction == "sum":
            return focal_loss.sum()
        else:
            return focal_loss


class XCiTClassifier(pl.LightningModule):
    """LightningModule that trains an XCiT1d network with a focal loss.

    Args:
        input_channels: Number of input channels passed to XCiT1d.
        num_classes: Number of output classes (XCiT1d's ``n_features``).
        xcit_version: XCiT variant name passed to XCiT1d.
        ds_method: Downsampling method passed to XCiT1d.
        ds_rate: Downsampling rate passed to XCiT1d.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(
        self,
        input_channels: int,
        num_classes: int,
        xcit_version: str = "tiny_12_p16_224",
        ds_method: str = "downsample",
        ds_rate: int = 16,
        learning_rate: float = 1e-3,
    ):
        super().__init__()
        self.save_hyperparameters()

        network_config = dict(
            input_channels=input_channels,
            n_features=num_classes,
            xcit_version=xcit_version,
            ds_method=ds_method,
            ds_rate=ds_rate,
        )
        self.model = XCiT1d(**network_config)
        self.learning_rate = learning_rate
        self.criterion = FocalLoss(gamma=2.0, alpha=None, reduction="mean")

        # Per-step histories kept on the module in addition to Lightning's logging.
        self.train_losses = []
        self.val_losses = []
        self.val_accuracies = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits produced by the wrapped XCiT1d model."""
        return self.model(x)

    def _loss_and_accuracy(self, batch):
        """Compute the loss and batch accuracy for one ``(inputs, labels)`` batch."""
        inputs, labels = batch
        logits = self(inputs.float())
        loss = self.criterion(logits, labels)
        hits = logits.argmax(dim=1) == labels
        return loss, hits.float().mean()

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        """Log epoch-level training loss/accuracy and return the loss to optimise."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        self.log("train_acc", acc, on_step=False, on_epoch=True)

        self.train_losses.append(loss.item())
        return loss

    def validation_step(self, batch, batch_idx) -> None:
        """Log validation loss/accuracy to the progress bar and record both values."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        self.val_losses.append(loss.item())
        self.val_accuracies.append(acc.item())

    def configure_optimizers(self):
        """Use Adam with cosine annealing spanning the trainer's ``max_epochs``."""
        adam = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        cosine_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(
            adam, T_max=self.trainer.max_epochs
        )
        return [adam], [cosine_schedule]

In [None]:
%pip install torchinfo

In [None]:
from torchinfo import summary

# Instantiate the classifier for 2 input channels and num_classes outputs,
# then display a layer/parameter summary as the cell output.
model = XCiTClassifier(input_channels=2, num_classes=num_classes)
summary(model)

## Train the Model

Using the [PyTorch Lightning Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html), we can train our model for modulation recognition on the IQ dataset.

In [None]:
import torch
import pytorch_lightning as pl

num_epochs = 1

# Cap the number of batches per epoch so this demo finishes quickly.
trainer = pl.Trainer(
    limit_train_batches=50,
    limit_val_batches=5,
    max_epochs=num_epochs,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
)

# Pass the validation dataloader as well: without it validation_step never runs,
# limit_val_batches has no effect, and val_loss / val_acc are never logged.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

## Test the Model

Now that we've trained the model, we can test its predictions.

In [None]:
from torchsig.datasets.datasets import TorchSigIterableDataset, StaticTorchSigDataset
from torchsig.utils.writer import DatasetCreator, default_collate_fn
import torch

# Put the trained model on the evaluation device and switch it to inference mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
torch.cuda.empty_cache()

# Number of examples written to the test split. This is the value that is actually
# passed to DatasetCreator below (previously hard-coded there as 100).
test_dataset_size = 100

# Fresh on-the-fly generator for the test split.
dataset = TorchSigIterableDataset(
    metadata=dataset_metadata,
    transforms=transforms,
    target_labels=None,
)

dataloader = WorkerSeedingDataLoader(dataset, collate_fn=lambda x: x)
# Fixed seed so the same test split is regenerated on every run.
dataloader.seed(1234)

dc = DatasetCreator(
    dataloader=dataloader,
    root=f"{root}/test",
    overwrite=True,
    dataset_length=test_dataset_size,
)
dc.create()

# Re-open the written split as a static dataset that returns the class index as target.
test_dataset = StaticTorchSigDataset(root=f"{root}/test", target_labels=["class_index"])

data, class_index = test_dataset[0]
print(f"Data shape: {data.shape}")
print(f"Targets: {class_index}")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data, class_index = test_dataset[0]
# move the model to the same device the data will be placed on
model.to(device)
# turn the model into evaluation mode
model.eval()
with torch.no_grad():  # no gradients are needed for inference
    # Convert to a float tensor (the same cast the training/validation steps apply)
    # and add a batch dimension.
    batch = torch.from_numpy(data).float().unsqueeze(dim=0).to(device)
    # The model returns one raw score (logit) per signal class, not probabilities.
    pred = model(batch)
    # print(torch.softmax(pred, dim=1))  # if you want to see per-class probabilities

    # Choose the class with the highest score, as a plain Python int.
    predicted_class = int(torch.argmax(pred).item())

print(f"Predicted = {predicted_class} ({class_list[predicted_class]})")
print(f"Actual = {class_index} ({class_list[class_index]})")

In [None]:
# We can do this over the whole test dataset to check the accuracy of our model
predictions = []
true_classes = []
num_correct = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Device placement and eval mode only need to be set once, not per sample.
model.to(device)
model.eval()
with torch.no_grad():
    for data, actual_class in test_dataset:
        # Same float cast the training/validation steps apply, plus a batch dimension.
        batch = torch.from_numpy(data).float().unsqueeze(dim=0).to(device)
        pred = model(batch)
        # Plain Python int so it compares and indexes cleanly.
        predicted_class = int(torch.argmax(pred).item())
        predictions.append(predicted_class)
        true_classes.append(actual_class)
        if predicted_class == actual_class:
            num_correct += 1

# try increasing num_epochs or train dataset size to increase accuracy
print(f"Correct Predictions = {num_correct}")
# Scale the fraction to a percentage; max(..., 1) guards against an empty dataset.
print(f"Percent Correct = {100 * num_correct / max(len(test_dataset), 1):.1f}%")

In [None]:
%pip install scikit-learn

In [None]:
# We can also plot a confusion matrix using Sklearn's confusion matrix tool
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Fix the label set to every class index so the matrix always has one row/column
# per class, even for classes that never occur in the test split.
matrix = confusion_matrix(
    true_classes, predictions, labels=list(range(len(class_list)))
)
disp = ConfusionMatrixDisplay(matrix, display_labels=class_list)
# Vertical x tick labels keep the class names from overlapping.
disp.plot(xticks_rotation="vertical")
# Scale the figure with the number of classes so every label stays legible.
figure_side = max(8, 0.3 * len(class_list))
disp.figure_.set_size_inches(figure_side, figure_side)
disp.ax_.set_title("Confusion matrix on the test dataset");