# Example 02 - TorchSig Narrowband Classifier
This notebook walks through a simple example of how to use the TorchSig Narrowband dataset and trainer. You can train a supported model from scratch or load a pre-trained checkpoint, then evaluate the trained network's performance. The experiment and its results are not meant to be significant in themselves; they serve as a practical example of how the `torchsig` dataset and tools can be integrated into a typical [PyTorch](https://pytorch.org/) and/or [PyTorch Lightning](https://www.pytorchlightning.ai/) workflow.

----

In [6]:
# TorchSig imports
from torchsig.transforms.target_transforms import DescToClassIndex
from torchsig.transforms.transforms import (
    RandomPhaseShift,
    Normalize,
    ComplexTo2D,
    Compose,
)
from torchsig.utils.narrowband_trainer import NarrowbandTrainer
from torchsig.datasets.torchsig_narrowband import TorchSigNarrowband
from torchsig.datasets.datamodules import NarrowbandDataModule
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt

----
### Instantiate TorchSigNarrowband Dataset
Here, we set up a `NarrowbandDataModule` that provides the TorchSigNarrowband training and validation datasets (with `impaired=True`, the datamodule is configured for the impaired variant rather than the clean one). We demonstrate how to compose multiple TorchSig transforms. The first, `RandomPhaseShift`, is a data impairment that uniformly samples a phase offset between -π and +π. The second, `Normalize`, normalizes the complex tensor, and the last, `ComplexTo2D`, converts the complex data to a real-valued tensor with the real and imaginary parts as two channels. We additionally provide a target transform that maps the `SignalMetadata` objects, which are part of `SignalData` objects, to the format expected by the model we will train. In this case, we use the `DescToClassIndex` target transform to map class names to their indices within an ordered class list. A sample is drawn from the dataset and printed later in the notebook to confirm functionality.

For more details on the TorchSigNarrowband dataset instantiations, please see `00_example_narrowband_dataset.ipynb`.
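
To make the transform chain concrete, the following is a minimal NumPy sketch of the three operations as described above, plus the class-name-to-index mapping. It is not TorchSig's implementation: the random test signal, the 4096-sample length, and the three-entry `example_classes` list are assumptions made purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical complex baseband (IQ) signal; 4096 samples is an assumed length.
iq = rng.standard_normal(4096) + 1j * rng.standard_normal(4096)

# 1) Random phase shift: rotate by a phase drawn uniformly from [-pi, pi).
phase = rng.uniform(-1.0, 1.0) * np.pi
shifted = iq * np.exp(1j * phase)

# 2) Infinity-norm normalization: scale so the largest magnitude is 1.
normalized = shifted / np.linalg.norm(shifted, ord=np.inf)

# 3) Complex to 2D: real and imaginary parts become two real channels.
two_channel = np.stack([normalized.real, normalized.imag], axis=0)

# Target transform: map a class name to its position in an ordered list.
example_classes = ["bpsk", "qpsk", "8psk"]  # hypothetical subset of class names
label_index = example_classes.index("qpsk")

print(two_channel.shape)  # (2, 4096)
print(label_index)        # 1
```

The phase rotation leaves sample magnitudes unchanged, so only the normalization step rescales the signal. The resulting two-channel layout is what `input_channels = 2` refers to when the trainer is configured.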

In [7]:
class_list = list(TorchSigNarrowband._idx_to_name_dict.values())
num_classes = len(class_list)

# Specify Transforms
transform = Compose(
    [
        RandomPhaseShift(phase_offset=(-1, 1)),
        Normalize(norm=np.inf),
        ComplexTo2D(),
    ]
)
target_transform = DescToClassIndex(class_list=class_list)

datamodule = NarrowbandDataModule(
    root='./datasets/narrowband_test_QA',
    qa=True,
    impaired=True,
    transform=transform,
    target_transform=target_transform,
    batch_size=32,
    num_workers=16,
)

---
### Instantiate and Initialize the NarrowbandTrainer with specified parameters.

    Args:
        model_name (str): Name of the model to use.
        num_epochs (int): Number of training epochs.
        batch_size (int): Batch size for training.
        num_workers (int): Number of workers for data loading.
        learning_rate (float): Learning rate for the optimizer.
        input_channels (int): Number of input channels into model.
        data_path (str): Path to the dataset.
        impaired (bool): Whether to use the impaired dataset.
        qa (bool): Whether to use QA configuration.
        checkpoint_path (str): Path to a checkpoint file to load the model weights.
        datamodule (LightningDataModule): Custom data module instance.


In [8]:
# Initialize the trainer with desired parameters
trainer = NarrowbandTrainer(
    model_name = 'xcit',
    num_epochs = 2,
    # batch_size = 32, # Uncomment if not passing in datamodule
    # num_workers = 16, # Uncomment if not passing in datamodule
    learning_rate = 1e-3,
    input_channels = 2,
    # data_path = '../datasets/narrowband_test_QA', # Uncomment if not passing in datamodule
    # impaired = True, # Uncomment if not passing in datamodule
    # qa = False # Uncomment if not passing in datamodule
    datamodule = datamodule,
    checkpoint_path = None # If loading checkpoint, add path here
)

Using custom datamodule provided.


In [9]:
# View all available models
print(trainer.available_models)

{'xcit': 'XCiTClassifier', 'inception': 'InceptionTime', 'MyNewModel': 'MyNewModel'}


---
### Train or Fine Tune your model.
    You can load any PyTorch Lightning checkpoint by providing its path as `checkpoint_path` above; otherwise, the specified model is trained from scratch.

In [10]:
# Train the model
trainer.train()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


RuntimeError: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero.

---
### Validate model
    You can validate a model by loading its checkpoint in the initialization stage or after training.

In [None]:
trainer.validate()

In [None]:
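# Train accuracy and loss plots
# cv2.imread returns BGR images; convert to RGB so matplotlib shows correct colors
acc_plot = cv2.cvtColor(cv2.imread(trainer.acc_plot_path), cv2.COLOR_BGR2RGB)
loss_plot = cv2.cvtColor(cv2.imread(trainer.loss_plot_path), cv2.COLOR_BGR2RGB)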

plots = [acc_plot, loss_plot]

fig = plt.figure(figsize=(21, 6))
r = 1
c = 3

for i in range(2):
    fig.add_subplot(r, c, i + 1)
    plt.imshow(plots[i])
    plt.axis('off') 

plt.show()

In [None]:
# Confusion matrix (converted from BGR to RGB for matplotlib)
cm_plot = cv2.cvtColor(cv2.imread(trainer.cm_plot_path), cv2.COLOR_BGR2RGB)
plt.imshow(cm_plot, aspect='auto')
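
The confusion matrix image itself is produced by the trainer. As background on what such a plot summarizes, here is a minimal NumPy sketch that builds a confusion matrix and an overall accuracy from true and predicted class indices. This is not the trainer's implementation; the `confusion_matrix` helper and the six made-up labels are assumptions for illustration.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """Count predictions: rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when index pairs repeat.
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

# Hypothetical labels for six examples across three classes.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])

cm = confusion_matrix(y_true, y_pred, num_classes=3)
accuracy = np.trace(cm) / cm.sum()  # correct predictions lie on the diagonal

print(cm)
print(accuracy)
```

Off-diagonal entries show which classes are mistaken for which, which is often more informative than accuracy alone.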

---
### Predict with model
    You can make predictions with a model by loading its checkpoint in the initialization stage or after training.

#### Load Data
    You can load whatever data you wish, assuming it is a torch.Tensor.
    In this example, we load a sample from our validation set.

    Data needs to be shape (batch_size, input_channels, data_length). You can use tensor.unsqueeze(dim=0) to add a batch dimension.
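
As a small illustration of the shape requirement, the sketch below builds a single two-channel example and adds the batch dimension. The array contents and the 4096-sample length are assumptions for illustration; only standard `numpy` and `torch` calls are used.

```python
import numpy as np
import torch

# Hypothetical single example: 2 channels (real, imaginary) by 4096 samples.
example = np.zeros((2, 4096), dtype=np.float64)

# Convert to a float32 tensor and prepend a batch dimension of size 1.
batch = torch.from_numpy(example).float().unsqueeze(dim=0)

print(tuple(batch.shape))  # (1, 2, 4096)
```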

In [None]:
import torch
datamodule.prepare_data()
datamodule.setup("fit")

# Retrieve a sample and print out information to verify
idx = np.random.randint(len(datamodule.val))
data, label = datamodule.val[idx]  # sample from the validation set, matching idx above
data = torch.tensor(data).float().unsqueeze(dim=0)
print("Dataset length: {}".format(len(datamodule.val)))
print("Data shape: {}".format(data.shape))
print("Label Index: {}".format(label))
print("Label Class: {}".format(TorchSigNarrowband.convert_idx_to_name(label)))

In [None]:
# Predict on the sample loaded above (`data` must be a torch.Tensor)
predictions = trainer.predict(data)[0]
print(TorchSigNarrowband._idx_to_name_dict[predictions])
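
For context on what a class prediction involves, the sketch below turns per-class scores into class names by taking the highest-scoring index and looking it up in an ordered class list. It assumes a classifier that emits one score per class; the scores and the `example_classes` list are hypothetical, and this is not a description of how `trainer.predict` works internally.

```python
import numpy as np

example_classes = ["bpsk", "qpsk", "8psk"]  # hypothetical class names

# Hypothetical raw scores for a batch of two examples, one score per class.
logits = np.array([[0.1, 2.3, -1.0],
                   [1.7, 0.2, 0.4]])

# The predicted class is the index of the largest score in each row.
predicted_indices = logits.argmax(axis=1)
predicted_names = [example_classes[i] for i in predicted_indices]

print(predicted_names)  # ['qpsk', 'bpsk']
```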