# Example 07 - TorchSig Narrowband with YOLOv8 Classifier
This notebook showcases using the TorchSig Narrowband dataset to train a YOLOv8 classification model.

---

## Import Libraries

In [1]:
# Packages for Training
from torchsig.utils.yolo_classify import *
from torchsig.utils.classify_transforms import real_imag_vstacked_cwt_image, complex_iq_to_heatmap
import yaml

In [2]:
# Packages for testing/inference
from torchsig.datasets.modulations import ModulationsDataset
from torchsig.datasets.signal_classes import torchsig_signals
from torchsig.transforms.target_transforms import DescToFamilyName
from torchsig.transforms.transforms import Spectrogram, SpectrogramImage, Normalize, Compose, Identity
from ultralytics import YOLO
import torch
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import os

## Prepare the YOLO Classification Trainer and Model
Datasets are generated on the fly in an Ultralytics YOLO-compatible format. See [Ultralytics: Train Custom Data - Organize Directories](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#23-organize-directories) to learn more.

We also create a YAML file for the dataset configuration (`07_yolo.yaml` in this notebook). See `classify.yaml` in the TorchSig examples directory for a reference.
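As a quick sanity check, the sketch below verifies that a loaded config dictionary contains the keys this notebook writes in cell [4], that `nc` agrees with the number of class names, and that `num_samples` is positive. The helper `check_dataset_config` is hypothetical (not part of TorchSig), and treating these particular keys as required is an assumption based on this notebook, not on `YoloClassifyTrainer`'s documented requirements.

```python
# Hypothetical helper (not part of TorchSig): sanity-check a dataset config dict.
# EXPECTED_KEYS mirrors the keys this notebook writes; the trainer's actual
# requirements are an assumption and may differ.
EXPECTED_KEYS = {"overrides", "train", "val", "level", "include_snr",
                 "num_samples", "nc", "names", "family", "families"}


def check_dataset_config(cfg: dict) -> list:
    """Return a list of problems; an empty list means the config looks sane."""
    problems = [f"missing key: {key}" for key in sorted(EXPECTED_KEYS - set(cfg))]
    # The number of classes must match the number of class names
    if "nc" in cfg and "names" in cfg and cfg["nc"] != len(cfg["names"]):
        problems.append(f"nc={cfg['nc']} but {len(cfg['names'])} names given")
    # Dataset generation needs at least one sample
    if cfg.get("num_samples", 1) <= 0:
        problems.append("num_samples must be a positive integer")
    return problems


example = {
    "overrides": {"imgsz": 512}, "train": "./wideband/", "val": "./wideband/",
    "level": 2, "include_snr": False, "num_samples": 530,
    "nc": 2, "names": {0: "ook", 1: "bpsk"},  # made-up two-class subset
    "family": False, "families": {0: "ask", 1: "psk"},
}
print(check_dataset_config(example))  # []
```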

Download the desired YOLO model from [Ultralytics Models](https://docs.ultralytics.com/models/). We use YOLOv8, specifically `yolov8n-cls.pt`.

---

In [3]:
config_path = '07_yolo.yaml'
with open(config_path, 'r') as file:
    config = yaml.safe_load(file)

overrides = config['overrides']

### Explanation of the `overrides` Dictionary

The `overrides` dictionary customizes the Ultralytics YOLO trainer: each key it sets replaces the corresponding default training setting. Here it is loaded from the YAML config file, but you can also define or edit it directly in the notebook.

Example:

```python
overrides = {'model': 'yolov8n-cls.pt', 'epochs': 100, 'data': 'classify.yaml', 'device': 0, 'imgsz': 64}
```
A YAML file is required for training; see `classify.yaml` in the examples directory for a reference. It contains the path to your TorchSig data.
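Conceptually, the trainer merges `overrides` into its default settings: each key you supply replaces the default, every other default is kept, and unknown keys are rejected. The sketch below illustrates that idea with a small, made-up `DEFAULTS` dictionary and a hypothetical `merge_overrides` helper; Ultralytics' real defaults and validation logic live in its own configuration module and may behave differently in detail.

```python
# Illustrative only: DEFAULTS is a made-up subset, not Ultralytics' real default config.
DEFAULTS = {"epochs": 100, "imgsz": 224, "batch": 16, "device": None, "workers": 8}


def merge_overrides(defaults: dict, overrides: dict) -> dict:
    """Return a new config where keys in `overrides` replace those in `defaults`."""
    unknown = set(overrides) - set(defaults)
    if unknown:
        # Catch typos such as 'epoch' instead of 'epochs'
        raise KeyError(f"unknown override keys: {sorted(unknown)}")
    # Later entries win, so override values replace defaults; inputs are not mutated
    return {**defaults, **overrides}


cfg = merge_overrides(DEFAULTS, {"epochs": 5, "imgsz": 512})
print(cfg["epochs"], cfg["imgsz"], cfg["batch"])  # 5 512 16
```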

### Explanation of the `image_transform` Parameter
`YoloClassifyTrainer` accepts any transform that takes complex I/Q data and returns an image for training. Example transforms are available in `torchsig.utils.classify_transforms` (e.g., `real_imag_vstacked_cwt_image` and `complex_iq_to_heatmap`, imported above). If no transform is passed, spectrogram images are used by default. Make sure `imgsz` in `overrides` matches the size of the transform's output image.
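For illustration, here is a minimal NumPy sketch of a custom transform that maps complex I/Q to a square, 3-channel `uint8` log-magnitude spectrogram image. The function name `iq_to_spectrogram_image` is hypothetical, and the exact input/output contract `YoloClassifyTrainer` expects (array type, channel order, dtype) is an assumption; compare against the transforms in `torchsig.utils.classify_transforms` before relying on it. The output side length equals `imgsz`, which is why `overrides['imgsz']` must match.

```python
import numpy as np


def iq_to_spectrogram_image(iq: np.ndarray, imgsz: int) -> np.ndarray:
    """Hypothetical transform: complex I/Q -> (imgsz, imgsz, 3) uint8 image.

    Assumes len(iq) >= imgsz**2; extra samples are dropped.
    """
    n = imgsz * imgsz
    if iq.size < n:
        raise ValueError(f"need at least {n} samples, got {iq.size}")
    # Split into imgsz non-overlapping segments of imgsz samples (rows = time)
    segments = iq[:n].reshape(imgsz, imgsz)
    # FFT each segment; fftshift moves 0 Hz to the center column (columns = frequency)
    spectrum = np.fft.fftshift(np.fft.fft(segments, axis=1), axes=1)
    power_db = 10.0 * np.log10(np.abs(spectrum) ** 2 + 1e-12)
    # Min-max scale to [0, 255] for an 8-bit image
    lo, hi = power_db.min(), power_db.max()
    scaled = (power_db - lo) / (hi - lo) if hi > lo else np.zeros_like(power_db)
    gray = (scaled * 255).astype(np.uint8)
    # Replicate to 3 channels, since the YOLO classifier's first conv takes 3 inputs
    return np.stack([gray] * 3, axis=-1)


rng = np.random.default_rng(0)
iq = rng.standard_normal(64 * 64) + 1j * rng.standard_normal(64 * 64)
img = iq_to_spectrogram_image(iq, imgsz=64)
print(img.shape, img.dtype)  # (64, 64, 3) uint8
```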

In [4]:
# define dataset variables for yaml file
config_name = "07_yolo.yaml"
family_list = ["ask", "fsk", "ofdm", "pam", "psk", "qam"]
family_dict = {idx: name for idx, name in enumerate(family_list)}
classes = {idx: name for idx, name in enumerate(torchsig_signals.class_list)}
num_classes = len(classes)
yolo_root = "./wideband/" # train/val images (relative to './datasets')

# define overrides
overrides = dict(
    model = "yolov8n-cls.pt",
    project = "yolo",
    name = "07_example",
    epochs = 5,
    imgsz = 512,
    data = config_name,
    device = 0 if torch.cuda.is_available() else "cpu",
    batch = 32,
    workers = 8
)

# create yaml file for trainer
yolo_config = dict(
    overrides = overrides,
    train = yolo_root,
    val = yolo_root,
    level = 2,
    include_snr = False,
    num_samples = 530,
    nc = num_classes,
    names = classes,
    family = False, # False: classify every individual signal class; True: classify modulation families (see `families`)
    families = family_dict
)

with open(config_name, 'w+') as file:
    yaml.dump(yolo_config, file, default_flow_style=False)

print(f"Creating experiment -> {overrides['name']}")

Creating experiment -> 07_example
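The `family` flag in the config above switches which label set the classifier uses. The sketch below shows that selection with plain dictionaries laid out like `family_dict` and `classes`; `select_label_names` is a hypothetical helper, and how `YoloClassifyTrainer` applies the flag internally is an assumption, not taken from TorchSig. Variable names are prefixed with `demo_` so running it in the notebook does not clobber `config`.

```python
def select_label_names(cfg: dict) -> dict:
    """Hypothetical: return the {index: name} mapping implied by cfg['family']."""
    return cfg["families"] if cfg["family"] else cfg["names"]


demo_families = {idx: name for idx, name in enumerate(["ask", "fsk", "ofdm", "pam", "psk", "qam"])}
demo_names = {0: "ook", 1: "bpsk", 2: "qpsk"}  # small made-up subset of class names

demo_cfg = {"family": False, "names": demo_names, "families": demo_families}
print(len(select_label_names(demo_cfg)))  # 3
demo_cfg["family"] = True
print(select_label_names(demo_cfg)[4])    # psk
```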


### Build YoloClassifyTrainer
This instantiates the YOLO classification trainer with the overrides specified above. Passing `image_transform=None` uses the default spectrogram images.

In [5]:
trainer = YoloClassifyTrainer(overrides=overrides, image_transform=None)

Ultralytics 8.3.3 🚀 Python-3.10.12 torch-2.4.1+cu121 CUDA:0 (NVIDIA GeForce RTX 4090 Laptop GPU, 15981MiB)
engine/trainer: task=classify, mode=train, model=yolov8n-cls.pt, data=07_yolo.yaml, epochs=5, time=None, patience=100, batch=32, imgsz=512, save=True, save_period=-1, cache=False, device=0, workers=8, project=yolo, name=07_example2, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=False, save_conf=False, save_crop=False, show_labels=

### Begin Training

In [6]:
trainer.train()

TensorBoard: Start with 'tensorboard --logdir yolo/07_example2', view at http://localhost:6006/
Overriding model.yaml nc=1000 with nc=62

                   from  n    params  module                                       arguments                     
  0                  -1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]                 
  1                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]                
  2                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]             
  3                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]                
  4                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]             
  5                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]               
  6                  -1  2    197632  ultralytics.n

ValueError: num_samples should be a positive integer value, but got num_samples=0

CWSpikeDataset._generate_samples called
SPIKEDataset


### Instantiate Test Dataset

This cell uses TorchSig's `ModulationsDataset` to generate a narrowband classification test dataset for inference.

In [None]:
import numpy as np

# Reload the config written above so the test data matches the training setup
with open(config_name, 'r') as file:
    config = yaml.safe_load(file)

# Determine whether to map descriptions to family names
if config['family']:
    target_transform = Compose([DescToFamilyName()])
else:
    target_transform = None

# Build an imgsz x imgsz spectrogram image from imgsz**2 I/Q samples
transform = Compose([
    Spectrogram(nperseg=overrides['imgsz'], noverlap=0, nfft=overrides['imgsz'], mode='psd'),
    Normalize(norm=np.inf, flatten=True),
    SpectrogramImage(),
])

class_list = list(config['names'].values())

dataset = ModulationsDataset(
    classes=class_list,
    use_class_idx=False,
    level=config['level'],
    num_iq_samples=overrides['imgsz']**2,
    num_samples=int(config['nc'] * 10),
    include_snr=config['include_snr'],
    transform=transform,
    target_transform=target_transform
)

# Retrieve a sample and print out information
idx = np.random.randint(len(dataset))
data, label = dataset[idx]
print("Dataset length: {}".format(len(dataset)))
print("Data shape: {}".format(data.shape))

# Collect 10 random samples and their labels for inference
samples = []
labels = []
for _ in range(10):
    idx = np.random.randint(len(dataset))
    sample, label = dataset[idx]
    samples.append(sample)
    labels.append(label)
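Why `num_iq_samples=overrides['imgsz']**2`? With `noverlap=0` and `nperseg = nfft = imgsz`, the signal splits into `imgsz**2 / imgsz = imgsz` non-overlapping segments of `imgsz` frequency bins each, giving a square `imgsz x imgsz` image that matches the trainer's `imgsz`. The NumPy sketch below checks that arithmetic with a plain segment-and-FFT and a standard sliding-window segment count (assuming no padding); `num_segments` is a hypothetical helper, and this is not TorchSig's `Spectrogram` implementation, which may add windowing and PSD scaling.

```python
import numpy as np


def num_segments(num_iq_samples: int, nperseg: int, noverlap: int = 0) -> int:
    """Full segments a window of nperseg samples fits with the given overlap (no padding)."""
    step = nperseg - noverlap
    return (num_iq_samples - noverlap) // step


imgsz = 512
n_iq = imgsz ** 2  # 262144 samples, as in the cell above
print(num_segments(n_iq, nperseg=imgsz))  # 512

# Segment-and-FFT with no overlap yields a square (time x frequency) array
demo_iq = np.exp(2j * np.pi * 0.1 * np.arange(n_iq))
demo_spec = np.abs(np.fft.fft(demo_iq.reshape(-1, imgsz), n=imgsz, axis=1)) ** 2
print(demo_spec.shape)  # (512, 512)
```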

### Predictions / Inference
The following cells load the best weights from training (`best.pt`) and run prediction on the test samples.

In [None]:
%matplotlib inline

In [None]:
model = YOLO(trainer.best)  # The model remembers the configuration from training
results = model.predict(samples, save=True)

In [None]:
# Plot prediction results (predict(save=True) wrote annotated images to results_dir)
rows = 3
cols = 3
fig = plt.figure(figsize=(15, 15))
results_dir = results[0].save_dir

for y, result in enumerate(results[:rows * cols]):
    imgpath = os.path.join(results_dir, "image" + str(y) + ".jpg")
    fig.add_subplot(rows, cols, y + 1)
    img = cv2.imread(imgpath)
    # OpenCV loads images as BGR; convert to RGB for matplotlib
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.title("Truth: " + str(labels[y]), fontsize='large', loc='left')
plt.show()
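To go beyond visual inspection, you can compare top-1 predictions against the ground-truth labels. In Ultralytics classification results, `result.probs.top1` is the index of the highest-probability class and `result.names` maps indices to names. The `top1_accuracy` helper below is hypothetical and assumes both lists hold class-name strings (as with `use_class_idx=False` and no family mapping here); that label format is an assumption worth checking on your data.

```python
def top1_accuracy(predicted: list, truth: list) -> float:
    """Fraction of positions where the predicted name equals the true name."""
    if len(predicted) != len(truth):
        raise ValueError("predicted and truth must have the same length")
    if not truth:
        return 0.0
    return sum(p == t for p, t in zip(predicted, truth)) / len(truth)


# Usage with the notebook objects (not executed here):
#   predicted = [r.names[r.probs.top1] for r in results]
#   print(top1_accuracy(predicted, [str(label) for label in labels]))
print(top1_accuracy(["bpsk", "qpsk", "ook", "8psk"], ["bpsk", "qpsk", "4ask", "8psk"]))  # 0.75
```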
