# Training a Model on Narrowband for Classification

This notebook demonstrates how to train a PyTorch model on the Narrowband dataset for modulation recognition.

---

In [1]:
# Variables: configuration shared by every cell below
from torchsig.signals.signal_lists import TorchSigSignalLists
from torchsig.transforms.dataset_transforms import ComplexTo2D
from torchsig.transforms.target_transforms import ClassIndex

# Directory the generated datasets live under
root = "./datasets/narrowband_classifier_example"

# Each dataset sample holds fft_size**2 IQ samples
fft_size = 256
num_iq_samples_dataset = fft_size ** 2

# Classify over every signal in TorchSig's signal list
class_list = TorchSigSignalLists.all_signals
num_classes = len(class_list)

# Dataset sizes scale with the number of classes:
# roughly 10 training and 2 validation samples per class
num_samples_train = num_classes * 10
num_samples_val = num_classes * 2

# Level 0: no impairments applied
impairment_level = 0
seed = 123456789

# ComplexTo2D turns an IQ array of complex values into a 2D array, with one
# channel for the real component and the other for the imaginary component
transforms = [ComplexTo2D()]
# ClassIndex turns our target labels into the index of the class according to class_list
target_transforms = [ClassIndex()]

## Create the Narrowband Dataset

In [None]:
from torchsig.datasets.dataset_metadata import NarrowbandMetadata
from torchsig.datasets.datamodules import NarrowbandDataModule

# Describes how the train/validation signals are generated
dataset_metadata = NarrowbandMetadata(
    num_iq_samples_dataset=num_iq_samples_dataset,
    fft_size=fft_size,
    impairment_level=impairment_level,
    class_list=class_list,
    seed=seed,
)

# create_* settings presumably apply while the dataset is generated, and
# batch_size / num_workers while it is served to the trainer (TODO confirm)
narrowband_datamodule = NarrowbandDataModule(
    root=root,
    dataset_metadata=dataset_metadata,
    num_samples_train=num_samples_train,
    num_samples_val=num_samples_val,
    transforms=transforms,
    target_transforms=target_transforms,
    create_batch_size=4,
    create_num_workers=4,
    batch_size=4,
    num_workers=4,
)
narrowband_datamodule.prepare_data()
narrowband_datamodule.setup()

# Peek at the first training sample to check its shape and label
first_train_sample = narrowband_datamodule.train[0]
data, targets = first_train_sample
print(f"Data shape: {data.shape}")
print(f"Targets: {targets}")

## Create the Model

We use our own XCiT model code and utilities, but this can be replaced with your own model architecture in PyTorch, Ultralytics, timm, etc.

In [None]:
from torchsig.models import XCiTClassifier
from torchinfo import summary

# Two input channels: the real and imaginary planes produced by ComplexTo2D.
# Sized for the num_classes entries of class_list.
model = XCiTClassifier(input_channels=2, num_classes=num_classes)

# Show the model summary as the cell output
summary(model)

## Train the Model

Using the [PyTorch Lightning Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html), we can train our model for modulation recognition on the Narrowband IQ dataset.

In [None]:
import torch
import pytorch_lightning as pl

num_epochs = 10

# Use the GPU when CUDA is available, otherwise fall back to the CPU
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

trainer = pl.Trainer(
    max_epochs=num_epochs,
    accelerator=accelerator,
    devices=1,
)

trainer.fit(model, narrowband_datamodule)

## Test the Model

Now that we've trained the model, we can test its predictions on a new dataset (not used in training).

In [None]:
from torchsig.datasets.narrowband import NewNarrowband, StaticNarrowband
from torchsig.utils.writer import DatasetCreator
import torch

# Release cached GPU memory left over from training before generating test data
torch.cuda.empty_cache()

test_dataset_size = 10

# Same signal settings as training, but a different seed so the test samples
# differ from the ones the model was trained on
dataset_metadata_test = NarrowbandMetadata(
    num_iq_samples_dataset = num_iq_samples_dataset,
    fft_size = fft_size,
    impairment_level = impairment_level,
    class_list = class_list,
    num_samples=test_dataset_size,
    transforms=transforms,
    target_transforms=target_transforms,
    seed = 123456788 # different than train
)

# Generate the test dataset under f"{root}/test" (overwrite=True replaces
# any previous copy)
dc = DatasetCreator(
    dataset = NewNarrowband(
        dataset_metadata = dataset_metadata_test,
    ),
    root = f"{root}/test",
    overwrite=True,
    batch_size=1,
    num_workers=1,
)
dc.create()

# Load the generated test dataset back from the same directory
test_narrowband = StaticNarrowband(
    root = f"{root}/test",
    impaired = impairment_level > 0,
)

# Inspect the first *test* sample; print its own label, not the training
# label `targets` left over from the datamodule cell
data, class_index = test_narrowband[0]
print(f"Data shape: {data.shape}")
print(f"Targets: {class_index}")

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data, class_index = test_narrowband[0]
# Move the model to the device the input tensor will be placed on
model.to(device)
# Put the model in evaluation mode (affects layers such as dropout)
model.eval()
with torch.no_grad():  # inference only: no gradient tracking needed
    # Convert to a tensor and add a leading batch dimension of size 1
    data_tensor = torch.from_numpy(data).to(device).unsqueeze(dim=0)
    # One score per signal class -- presumably raw logits; apply
    # torch.softmax if probabilities are needed. Print `pred` to inspect it.
    pred = model(data_tensor)

    # The class with the highest score is the prediction; .item() yields a
    # plain Python int, which is safe to use as a list index
    predicted_class = torch.argmax(pred).item()
    print(f"Predicted = {predicted_class} ({class_list[predicted_class]})")
    print(f"Actual = {class_index} ({class_list[class_index]})")

In [None]:
# We can do this over the whole test dataset to check the accuracy of our model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Prepare the model once instead of on every loop iteration
model.to(device)
model.eval()

num_correct = 0
with torch.no_grad():
    for data, actual_class in test_narrowband:
        # Add a batch dimension of size 1 and move to the model's device
        data_tensor = torch.from_numpy(data).to(device).unsqueeze(dim=0)
        pred = model(data_tensor)
        predicted_class = torch.argmax(pred).item()
        if predicted_class == actual_class:
            num_correct += 1

num_test_samples = len(test_narrowband)

# try increasing num_epochs or train dataset size to increase accuracy
print(f"Correct Predictions = {num_correct}")
if num_test_samples > 0:
    # Scale the fraction by 100 so the value matches the "%" sign
    print(f"Percent Correct = {100 * num_correct / num_test_samples:.2f}%")
else:
    print("Percent Correct = n/a (empty test dataset)")