# Example 13 - Using Our Pre-Trained Models
This notebook shows how to load and run our pretrained models on both Narrowband and Wideband data.

----

## Model Information
We provide two Narrowband models and two Wideband models, summarized in the table below.

| Name | Filename  | Dataset | Description | Input Data | Input Shape |
| ---- | --------  | ------- | ----------- | ---------- | ----------- |
| ConVit | convit_narrowband.pth | Narrowband | Classifies IQ samples into one of the TorchSig signal classes. | IQ | (2, 4096) |
| XCiT | xcit_narrowband.ckpt | Narrowband | Classifies IQ samples into one of the TorchSig signal classes. | IQ | (2, 4096) |
| YOLO Detect | yolo_detect.pt | Wideband | YOLO model trained for energy detection (draws bounding boxes around signals). | Spectrogram | (1024, 1024) |
| YOLO Classify | yolo_classify.pt | Wideband | YOLO model that detects signals and classifies them into their signal families. | Spectrogram | (1024, 1024) |

## Import Libraries

In [None]:
# TorchSig
from torchsig.datasets.signal_classes import torchsig_signals
from torchsig.transforms.target_transforms import DescToClassIndex, DescToBBoxFamilyDict, DescToBBoxDict
from torchsig.transforms.transforms import (
    RandomPhaseShift,
    Normalize,
    Compose,  
    ComplexTo2D,
    Spectrogram,
    SpectrogramImage
)
from torchsig.datasets.modulations import ModulationsDataset
from torchsig.datasets.wideband import WidebandModulationsDataset
from torchsig.utils.writer import DatasetCreator
from torchsig.datasets.torchsig_narrowband import TorchSigNarrowband
from torchsig.datasets.torchsig_wideband import TorchSigWideband

# Third Party
import torch
from torchinfo import summary
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2

# Built-In
import os

print("Imports Done.")

# Narrowband Models
We have two pretrained models for Narrowband: ConVit and XCiT.

## Create Test Narrowband Dataset
First, create the test dataset and save it to disk, then load it back in.

In [None]:
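# use the first GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")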
root = "./datasets/13_example"
num_workers = 8
seed = 1234567890
class_list = torchsig_signals.class_list
num_classes = len(class_list)

ds = ModulationsDataset(
    level = 2,
    num_samples = num_classes,
    num_iq_samples = 4096,
    eb_no = False,
    use_class_idx = True,
    include_snr = True
)

os.makedirs(root, exist_ok=True)

creator = DatasetCreator(
    ds,
    seed = seed,
    path = f"{root}/narrowband_impaired_val",
    num_workers = num_workers,
)

creator.create()

In [None]:
transform = Compose([
    RandomPhaseShift(phase_offset=(-1, 1)),
    Normalize(norm=np.inf),
    ComplexTo2D(),
])

target_transform = DescToClassIndex(class_list=class_list)


test_narrowband = TorchSigNarrowband(
    root,
    train = False,
    impaired = True,
    transform = transform,
    target_transform = target_transform,
    use_signal_data = True,
)

test_data_numpy, test_target = test_narrowband[0]
figure = plt.figure(figsize=(12, 6))
plt.subplot(1, 1, 1)
plt.plot(test_data_numpy[0][:100])
plt.plot(test_data_numpy[1][:100])
plt.xticks([])
plt.yticks([])
plt.title(f"Class Index: {test_target}, Class Name: {class_list[test_target]}")
print(f"Class Name: {class_list[test_target]}")

# convert the data to a float32 tensor, add a batch dimension, and move it to the model's device
test_data = torch.from_numpy(test_data_numpy).to(torch.float32).unsqueeze(0).to(device)
print(f"Data Shape: {test_data.shape}")
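To feed your own recordings to the narrowband models, the preprocessing above can be approximated without TorchSig. The sketch below assumes that `Normalize(norm=np.inf)` divides by the largest sample magnitude and that `ComplexTo2D` stacks the real (I) and imaginary (Q) parts into two rows. It skips the random phase shift, which is an augmentation. The helper name `iq_to_model_input` is hypothetical.

```python
import numpy as np

def iq_to_model_input(iq: np.ndarray, num_iq_samples: int = 4096) -> np.ndarray:
    """Hypothetical helper: complex IQ vector -> float32 array of shape (1, 2, num_iq_samples)."""
    if iq.ndim != 1 or iq.shape[0] != num_iq_samples:
        raise ValueError(f"expected a 1-D array of {num_iq_samples} complex samples")
    # assumed equivalent of Normalize(norm=np.inf): scale so the largest magnitude is 1
    peak = np.max(np.abs(iq))
    if peak > 0:
        iq = iq / peak
    # assumed equivalent of ComplexTo2D: row 0 = I (real part), row 1 = Q (imaginary part)
    two_channel = np.stack([iq.real, iq.imag]).astype(np.float32)
    # add the batch dimension expected by the models
    return two_channel[np.newaxis, ...]

rng = np.random.default_rng(0)
samples = rng.standard_normal(4096) + 1j * rng.standard_normal(4096)
x = iq_to_model_input(samples)
print(x.shape, x.dtype)  # (1, 2, 4096) float32
```

The result can then be wrapped with `torch.from_numpy(x).to(device)` in the same way as `test_data` above.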

## ConVit Model
The ConVit checkpoint can be downloaded from our hosted servers or from the release package (see the files attached to release v0.6.1). It is stored in the examples/ directory, the notebook's working directory. The cell below downloads the file only if it does not already exist, so if a previous download was incomplete or you want a fresh copy, delete the file manually before rerunning the notebook.

In [None]:
if not os.path.isfile("convit_narrowband.pth"):
    download_command = 'curl -L -o "convit_narrowband.pth" "https://bucket.ltsnet.net/torchsig/models/convit_narrowband.pth"'
    os.system(download_command)
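An interrupted `curl` run can leave a truncated checkpoint behind, which the `isfile` check then treats as complete. One way to avoid that is to download into a temporary file and rename it only once the transfer finishes. Below is a minimal sketch using only the Python standard library. `download_if_missing` is a hypothetical helper, not part of TorchSig.

```python
import os
import tempfile
import urllib.request

def download_if_missing(path: str, url: str) -> bool:
    """Download url to path unless path already exists. Returns True if a download happened."""
    if os.path.isfile(path):
        return False
    directory = os.path.dirname(os.path.abspath(path))
    # write into a temp file in the same directory so the final rename is atomic
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".part")
    os.close(fd)
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)  # only a complete file ever appears at `path`
    except BaseException:
        # remove the partial file on any failure (including KeyboardInterrupt)
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return True

# usage:
# download_if_missing("convit_narrowband.pth",
#                     "https://bucket.ltsnet.net/torchsig/models/convit_narrowband.pth")
```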

In [None]:
from torchsig.models.iq_models.convit import ConVit1DLightning

# load the checkpoint and move the model to the same device as the test data
convit_model = ConVit1DLightning.load_from_checkpoint("convit_narrowband.pth").to(device)
summary(convit_model)

# set model in evaluation mode
convit_model.eval()

In [None]:
# run inference on the test sample without tracking gradients
with torch.no_grad():
    pred = convit_model(test_data)
result = torch.argmax(pred, dim=1).cpu().item()

# compare results
print(f"Model Prediction = {result} | {class_list[result]}")
print(f"Actual = {test_target} | {class_list[test_target]}")
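`argmax` reports only the single best class. To see how confident the model is, you can convert its outputs to probabilities and list the top few classes. This sketch assumes the model returns unnormalized logits of shape `(batch, num_classes)`, which this notebook does not confirm. In the notebook you would pass `pred.cpu().numpy()[0]` and `class_list`. The helper `top_k_classes` is hypothetical.

```python
import numpy as np

def top_k_classes(logits, class_names, k=3):
    """Return (class_name, probability) pairs for the k most likely classes."""
    logits = np.asarray(logits, dtype=np.float64)
    # numerically stable softmax: subtract the max before exponentiating
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # indices sorted from most to least likely, truncated to k
    order = np.argsort(probs)[::-1][:k]
    return [(class_names[i], float(probs[i])) for i in order]

# toy example with made-up logits for three made-up classes
for name, p in top_k_classes([2.0, 0.5, 1.0], ["bpsk", "qpsk", "8psk"], k=2):
    print(f"{name}: {p:.3f}")
```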

## XCiT Model
The XCiT checkpoint can be downloaded from our hosted servers or from the release package (see the files attached to release v0.6.1). It is stored in the examples/ directory, the notebook's working directory. The cell below downloads the file only if it does not already exist, so if a previous download was incomplete or you want a fresh copy, delete the file manually before rerunning the notebook.

In [None]:
if not os.path.isfile("xcit_narrowband.ckpt"):
    download_command = 'curl -L -o "xcit_narrowband.ckpt" "https://bucket.ltsnet.net/torchsig/models/xcit_narrowband.ckpt"'
    os.system(download_command)

In [None]:
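from torchsig.models.iq_models.xcit import XCiTClassifier

# load the checkpoint and move the model to the same device as the test data
xcit_model = XCiTClassifier.load_from_checkpoint("xcit_narrowband.ckpt").to(device)
summary(xcit_model)

# set model in evaluation mode
xcit_model.eval()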

In [None]:
# run inference on the test sample without tracking gradients
with torch.no_grad():
    pred = xcit_model(test_data)
result = torch.argmax(pred, dim=1).cpu().item()

# compare results
print(f"Model Prediction = {result} | {class_list[result]}")
print(f"Actual = {test_target} | {class_list[test_target]}")
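A single sample says little about either model. For a fairer comparison of ConVit and XCiT, you could run every sample of `test_narrowband` through each model, collect the logits, and summarize them. The sketch below covers only the summary step, using numpy. Gathering the logits (for example, with a `torch.utils.data.DataLoader` loop under `torch.no_grad()`) is left out. `summarize_predictions` is a hypothetical helper.

```python
import numpy as np

def summarize_predictions(logits, targets, num_classes):
    """Return overall accuracy and a confusion matrix (rows = actual, cols = predicted)."""
    preds = np.argmax(np.asarray(logits), axis=1)
    targets = np.asarray(targets)
    confusion = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when (target, pred) pairs repeat
    np.add.at(confusion, (targets, preds), 1)
    accuracy = float(np.mean(preds == targets))
    return accuracy, confusion

# toy example: three samples, two classes, one misclassified
logits = np.array([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])
targets = [0, 1, 1]
acc, cm = summarize_predictions(logits, targets, num_classes=2)
print(acc)  # 2 of 3 correct
print(cm)
```

In the notebook, `num_classes` would be `len(class_list)`. Off-diagonal entries of the confusion matrix show which signal classes each model tends to mix up.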

# Wideband Models
We have two YOLO models trained on wideband spectrograms. `yolo_detect.pt` performs energy detection, while `yolo_classify.pt` performs family classification.

## Create Test Wideband Dataset

In [None]:
root = "./datasets/13_example"
overlap_prob = 0.1
num_workers = 8
seed = 1234567891
class_list = torchsig_signals.class_list
num_classes = len(class_list)
fft_size = 512

# fft_size**2 IQ samples split into fft_size non-overlapping segments give a fft_size x fft_size spectrogram
ds = WidebandModulationsDataset(
    level = 2,
    num_samples = num_classes,
    num_iq_samples = fft_size**2,
    seed = seed,
    overlap_prob = overlap_prob
)
os.makedirs(root, exist_ok=True)

creator = DatasetCreator(
    ds,
    seed = seed,
    path = f"{root}/wideband_impaired_val",
    num_workers = num_workers,
)

creator.create()
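The wideband dataset uses 512² IQ samples per example, which pairs with the spectrogram settings used for the YOLO models (`nperseg=512`, `noverlap=0`, `nfft=512`). Cutting 512² samples into non-overlapping 512-sample segments gives 512 segments, and each segment is transformed with a 512-point FFT, so the image is 512 × 512. The numpy sketch below reproduces only that shape arithmetic, using a rectangular window. It is an approximation for illustration, not TorchSig's `Spectrogram` implementation, which may use a different window, scaling, and orientation.

```python
import numpy as np

def simple_spectrogram(iq, nfft=512):
    """Rough PSD-style spectrogram: non-overlapping segments, rectangular window."""
    num_segments = len(iq) // nfft
    # one row per segment (time), one column per FFT bin (frequency)
    segments = np.asarray(iq[: num_segments * nfft]).reshape(num_segments, nfft)
    # FFT each segment and move zero frequency to the middle column
    spectrum = np.fft.fftshift(np.fft.fft(segments, n=nfft, axis=1), axes=1)
    return np.abs(spectrum) ** 2 / nfft

rng = np.random.default_rng(0)
iq = rng.standard_normal(512**2) + 1j * rng.standard_normal(512**2)
print(simple_spectrogram(iq).shape)  # (512, 512)
```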


## YOLO Detection Model

Below, we set up the dataset so that the data is transformed into spectrogram images and the labels into bounding-box information.

In [None]:
transform = Compose([
    Spectrogram(nperseg=512, noverlap=0, nfft=512, mode='psd'),
    Normalize(norm=np.inf, flatten=True),
    SpectrogramImage(), 
])

target_transform = DescToBBoxDict(
    class_list = class_list
)

test_detect_wideband = TorchSigWideband(
    root,
    train = False,
    impaired = True,
    transform = transform,
    target_transform = target_transform
)

test_data, test_target = test_detect_wideband[0]
print(f"Spectrogram: {test_data.shape}")
print(f"Class Indices: {test_target['labels']}")
print(f"Bounding Boxes (normalized xmin, ymin, width, height):\n {test_target['boxes']}")

# image arrays are (height, width, channels)
full_height, full_width, _ = test_data.shape
figure = plt.figure(figsize=(9, 9))
ax = plt.subplot(1, 1, 1)
ax.imshow(test_data)
for i, l in enumerate(test_target['labels']):
    # scale the normalized box to pixel coordinates
    norm_xstart, norm_ystart, norm_width, norm_height = test_target['boxes'][i]
    rect = patches.Rectangle(
        (norm_xstart*full_width, norm_ystart*full_height),
        norm_width * full_width,
        norm_height * full_height,
        linewidth = 2,
        edgecolor = "r",
        facecolor = "none"
    )
    ax.add_patch(rect)
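The two wideband target transforms in this notebook describe boxes differently. `DescToBBoxDict` boxes are printed as normalized `(xmin, ymin, width, height)`, while `DescToBBoxFamilyDict` boxes are normalized `(xcenter, ycenter, width, height)`. A small helper that turns either format into pixel corner coordinates, as the plotting loops do inline, might look like the sketch below. `to_pixel_corners` and its `origin` parameter are hypothetical.

```python
import numpy as np

def to_pixel_corners(boxes, image_height, image_width, origin="corner"):
    """Convert normalized (x, y, w, h) boxes to pixel (xmin, ymin, xmax, ymax).

    origin="corner": (x, y) is the top-left corner (DescToBBoxDict output above).
    origin="center": (x, y) is the box center (DescToBBoxFamilyDict output).
    """
    boxes = np.asarray(boxes, dtype=np.float64).reshape(-1, 4)
    x, y, w, h = boxes.T
    if origin == "center":
        # shift from the center to the top-left corner
        x = x - w / 2
        y = y - h / 2
    elif origin != "corner":
        raise ValueError("origin must be 'corner' or 'center'")
    # x scales with image width, y with image height
    xmin, xmax = x * image_width, (x + w) * image_width
    ymin, ymax = y * image_height, (y + h) * image_height
    return np.stack([xmin, ymin, xmax, ymax], axis=1)

# the same box expressed in both formats; both print [[128. 256. 384. 384.]]
print(to_pixel_corners([[0.25, 0.5, 0.5, 0.25]], 512, 512))
print(to_pixel_corners([[0.5, 0.625, 0.5, 0.25]], 512, 512, origin="center"))
```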

The YOLO detection model can be downloaded from our hosted servers or from the release package (see the files attached to release v0.6.1). It is stored in the examples/ directory, the notebook's working directory. The cell below downloads the file only if it does not already exist, so if a previous download was incomplete or you want a fresh copy, delete the file manually before rerunning the notebook.

In [None]:
if not os.path.isfile("yolo_detect.pt"):
    download_command = 'curl -L -o "yolo_detect.pt" "https://bucket.ltsnet.net/torchsig/models/yolo_detect.pt"'
    os.system(download_command)

In [None]:
from ultralytics import YOLO

yolo_detect_model = YOLO("yolo_detect.pt")
summary(yolo_detect_model)

# no explicit .eval() call is needed: predict() runs the network in inference mode

In [None]:
# run the model; detections with confidence below conf=0.5 are discarded
results = yolo_detect_model.predict(test_data, save=True, imgsz=512, conf=0.5)
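The `conf=0.5` argument tells ultralytics to drop detections whose confidence is below 0.5. Ultralytics also applies non-maximum suppression (NMS), so overlapping boxes around the same signal collapse into one. The minimal greedy NMS sketch below, written in numpy, makes those two post-processing steps concrete. It illustrates the technique and is not ultralytics' implementation. The 0.7 IoU threshold is an assumed example value.

```python
import numpy as np

def iou(box, others):
    """IoU between one (xmin, ymin, xmax, ymax) box and an (N, 4) array of boxes."""
    ix1 = np.maximum(box[0], others[:, 0])
    iy1 = np.maximum(box[1], others[:, 1])
    ix2 = np.minimum(box[2], others[:, 2])
    iy2 = np.minimum(box[3], others[:, 3])
    inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (others[:, 2] - others[:, 0]) * (others[:, 3] - others[:, 1])
    return inter / (area + areas - inter)

def filter_detections(boxes, scores, conf=0.5, iou_threshold=0.7):
    """Drop low-confidence boxes, then greedily suppress overlapping ones. Returns kept indices."""
    boxes = np.asarray(boxes, dtype=np.float64)
    scores = np.asarray(scores, dtype=np.float64)
    candidates = np.flatnonzero(scores > conf)
    # visit candidates from highest to lowest score
    order = candidates[np.argsort(scores[candidates])[::-1]]
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        # discard remaining boxes that overlap the kept box too much
        order = rest[iou(boxes[best], boxes[rest]) <= iou_threshold]
    return keep

boxes = [[0, 0, 10, 10], [1, 1, 10, 10], [20, 20, 30, 30], [40, 40, 50, 50]]
scores = [0.9, 0.8, 0.6, 0.3]
print(filter_detections(boxes, scores))  # [0, 2]
```

Box 1 overlaps box 0 with IoU 0.81 and is suppressed. Box 3 falls below the confidence threshold.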

In [None]:
%matplotlib inline

In [None]:
# compare results
results_dir = results[0].save_dir
# ultralytics saves an in-memory (numpy) input as image0.jpg in the results directory
imgpath = os.path.join(results_dir, "image0.jpg")

figure, ax = plt.subplots(1, 1, figsize=(9, 9))
figure.suptitle("YOLO Detection: Predicted (annotated image) vs. Actual (blue)")

# plot predicted bounding boxes
img = cv2.imread(imgpath)
ax.imshow(img)

# plot actual bounding boxes
full_height, full_width, _ = test_data.shape
for i,l in enumerate(test_target['labels']):
    norm_xstart, norm_ystart, norm_width, norm_height = test_target['boxes'][i]
    rect = patches.Rectangle(
        (norm_xstart*full_width, norm_ystart*full_height),
        norm_width * full_width,
        norm_height * full_height,
        linewidth = 2,
        edgecolor = "b",
        facecolor = "none"
    )
    ax.add_patch(rect)

ax.legend(["Actual"])
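The overlay above is a visual check. To turn it into numbers, you could pair actual and predicted boxes by IoU and count a match when the IoU reaches a threshold. In ultralytics, predicted boxes in normalized corner format should be available as a tensor from `results[0].boxes.xyxyn` (call `.cpu().numpy()` on it). The actual boxes would first need converting from normalized `(xmin, ymin, width, height)` to corners. The 0.5 threshold and the helpers `iou_matrix` and `match_boxes` are assumptions for illustration.

```python
import numpy as np

def iou_matrix(a, b):
    """Pairwise IoU between (N, 4) and (M, 4) arrays of (xmin, ymin, xmax, ymax) boxes."""
    a = np.asarray(a, dtype=np.float64).reshape(-1, 4)
    b = np.asarray(b, dtype=np.float64).reshape(-1, 4)
    # broadcast to (N, M) intersection rectangles
    ix1 = np.maximum(a[:, None, 0], b[None, :, 0])
    iy1 = np.maximum(a[:, None, 1], b[None, :, 1])
    ix2 = np.minimum(a[:, None, 2], b[None, :, 2])
    iy2 = np.minimum(a[:, None, 3], b[None, :, 3])
    inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    return inter / (area_a[:, None] + area_b[None, :] - inter)

def match_boxes(actual, predicted, threshold=0.5):
    """Greedily pair boxes by descending IoU; returns (pairs, recall, precision)."""
    ious = iou_matrix(actual, predicted)
    pairs, used_actual, used_pred = [], set(), set()
    # consider candidate pairs from highest IoU to lowest
    for flat in np.argsort(ious, axis=None)[::-1]:
        i, j = (int(v) for v in np.unravel_index(flat, ious.shape))
        if ious[i, j] < threshold:
            break
        if i in used_actual or j in used_pred:
            continue
        pairs.append((i, j))
        used_actual.add(i)
        used_pred.add(j)
    recall = len(pairs) / max(ious.shape[0], 1)
    precision = len(pairs) / max(ious.shape[1], 1)
    return pairs, recall, precision

actual = [[0, 0, 10, 10], [20, 20, 30, 30]]
predicted = [[1, 1, 10, 10], [50, 50, 60, 60]]
print(match_boxes(actual, predicted))  # ([(0, 0)], 0.5, 0.5)
```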

## YOLO Family Classification Model
Below, we set up the dataset so that the data is transformed into spectrogram images (reusing the `transform` defined for the detection model) and the labels into signal-family indices with bounding-box information.

In [None]:

target_transform = DescToBBoxFamilyDict(
    class_family_dict = torchsig_signals.family_dict
)
family_list = target_transform.family_list

# reuses the spectrogram `transform` defined in the detection section above
test_classify_wideband = TorchSigWideband(
    root,
    train = False,
    impaired = True,
    transform = transform,
    target_transform = target_transform
)

test_data, test_target = test_classify_wideband[0]
print(f"Spectrogram: {test_data.shape}")
print(f"Family Class Indices: {test_target['labels']}")
print(f"Bounding Boxes (normalized xcenter, ycenter, width, height):\n {test_target['boxes']}")

# image arrays are (height, width, channels)
full_height, full_width, _ = test_data.shape
figure = plt.figure(figsize=(9, 9))
ax = plt.subplot(1, 1, 1)
ax.imshow(test_data)
for i,l in enumerate(test_target['labels']):
    norm_xcenter, norm_ycenter, norm_width, norm_height = test_target['boxes'][i]
    width = norm_width * full_width
    height = norm_height * full_height
    xcenter = norm_xcenter * full_width
    ycenter = norm_ycenter * full_height
    xstart = xcenter - (width / 2)
    ystart = ycenter - (height / 2)
    rect = patches.Rectangle(
        (xstart, ystart),
        width,
        height,
        linewidth = 2,
        edgecolor = "r",
        facecolor = "none"
    )
    ax.add_patch(rect)
    ax.text(xcenter - 1, ycenter - 1, f"{family_list[l]}", backgroundcolor = 'gray', color = 'b', fontsize='small')

The YOLO family classification model can be downloaded from our hosted servers or from the release package (see the files attached to release v0.6.1). It is stored in the examples/ directory, the notebook's working directory. The cell below downloads the file only if it does not already exist, so if a previous download was incomplete or you want a fresh copy, delete the file manually before rerunning the notebook.

In [None]:
if not os.path.isfile("yolo_classify.pt"):
    download_command = 'curl -L -o "yolo_classify.pt" "https://bucket.ltsnet.net/torchsig/models/yolo_classify.pt"'
    os.system(download_command)

In [None]:
from ultralytics import YOLO

yolo_classify_model = YOLO("yolo_classify.pt")
summary(yolo_classify_model)

# no explicit .eval() call is needed: predict() runs the network in inference mode

In [None]:
# run the model; detections with confidence below conf=0.5 are discarded
results = yolo_classify_model.predict(test_data, save=True, imgsz=512, conf=0.5)

In [None]:
# compare results
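results_dir = results[0].save_dir
# ultralytics saves an in-memory (numpy) input as image0.jpg in the results directory
imgpath = os.path.join(results_dir, "image0.jpg")

figure, ax = plt.subplots(1, 1, figsize=(9, 9))
figure.suptitle("YOLO Family Classification: Predicted (annotated image) vs. Actual (red)")

# plot predicted bounding boxes
img = cv2.imread(imgpath)
ax.imshow(img)

# plot actual bounding boxes and family names
full_height, full_width, _ = test_data.shape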
for i,l in enumerate(test_target['labels']):
    norm_xcenter, norm_ycenter, norm_width, norm_height = test_target['boxes'][i]
    width = norm_width * full_width
    height = norm_height * full_height
    xcenter = norm_xcenter * full_width
    ycenter = norm_ycenter * full_height
    xstart = xcenter - (width / 2)
    ystart = ycenter - (height / 2)
    rect = patches.Rectangle(
        (xstart, ystart),
        width,
        height,
        linewidth = 2,
        edgecolor = "r",
        facecolor = "none"
    )
    ax.add_patch(rect)
    ax.text(xcenter - 1, ycenter - 1, f"{family_list[l]}", backgroundcolor = 'gray', color = 'b', fontsize='small')

ax.legend(["Actual"], bbox_to_anchor = (0, 1.02, 1, 0.2), loc ='lower left', mode = 'expand')