# Example 06 - TorchSigWideband with YOLOv8 Detector
This notebook showcases using the Wideband dataset to train a YOLOv8 model.

---

## Import Libraries

In [None]:
# Package imports for training
from torchsig.utils.yolo_train import *
from datetime import datetime
import numpy as np  # used below for random sampling and np.inf
import yaml

In [None]:
# Package Imports for Testing/Inference
from torchsig.datasets.datamodules import WidebandDataModule
from torchsig.datasets.signal_classes import torchsig_signals
from torchsig.transforms.transforms import Spectrogram, SpectrogramImage, Normalize, Compose, Identity
from torchsig.transforms.target_transforms import DescToBBoxFamilyDict
from ultralytics import YOLO
import torch
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import os

---
## Check or Generate the Wideband Dataset
To generate the TorchSigWideband dataset, several parameters are passed to the imported `WidebandDataModule` class. These parameters are:
- `root` ~ A string to specify the root directory of where to generate and/or read an existing TorchSigWideband dataset
- `train` ~ A boolean to specify if the TorchSigWideband dataset should be the training (True) or validation (False) sets
- `qa` ~ A boolean to specify whether to generate a small subset of Wideband (True) or the full dataset (False); the default is True
- `impaired` ~ A boolean to specify if the TorchSigWideband dataset should be the clean version or the impaired version
- `transform` ~ Optionally, pass in any data transforms here if the dataset will be used in an ML training pipeline. Note: these transforms are not called during the dataset generation. The static saved dataset will always be in IQ format. The transform is only called when retrieving data examples.
- `target_transform` ~ Optionally, pass in any target transforms here if the dataset will be used in an ML training pipeline. Note: these target transforms are not called during the dataset generation. The static saved dataset will always be saved as tuples in the LMDB dataset. The target transform is only called when retrieving data examples.
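
The note that transforms run only at retrieval time can be illustrated with a toy dataset. This is a minimal sketch, not TorchSig's implementation: `LazyTransformDataset` and the stored tuples are hypothetical stand-ins for the LMDB-backed dataset, and the lambdas are placeholder transforms.

```python
from typing import Any, Callable, List, Optional, Tuple


class LazyTransformDataset:
    """Hypothetical stand-in for an on-disk dataset; not part of TorchSig."""

    def __init__(
        self,
        records: List[Tuple[Any, Any]],
        transform: Optional[Callable[[Any], Any]] = None,
        target_transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self._records = records  # stored once; transforms never modify it
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self._records)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        data, target = self._records[idx]
        # Transforms are applied here, at retrieval time only.
        if self.transform is not None:
            data = self.transform(data)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return data, target


# Toy record: two IQ samples and a tuple-style label.
raw = [([1 + 1j, 2 - 1j], ("qam", 0.1, 0.4))]
dataset = LazyTransformDataset(
    raw,
    transform=lambda iq: [abs(x) ** 2 for x in iq],  # placeholder "power" transform
    target_transform=lambda t: {"class_name": t[0]},  # placeholder label transform
)
data, target = dataset[0]
```

Because the stored records are untouched, the same saved dataset can be reused with different transforms, which is how this notebook later reads spectrogram images from a dataset saved as IQ.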

A combination of the `impaired` and `qa` booleans determines which of the four distinct TorchSigWideband datasets is instantiated:
| `impaired` | `qa` | Result |
| ---------- | ---- | ------- |
| `False` | `False` | Clean datasets of train=250k examples and val=25k examples |
| `False` | `True` | Clean datasets of train=250 examples and val=250 examples |
| `True` | `False` | Impaired datasets of train=250k examples and val=25k examples |
| `True` | `True` | Impaired datasets of train=250 examples and val=250 examples |

The validation set of the full impaired dataset (`impaired=True`, `qa=False`) is the one to use when reporting results on the official TorchSigWideband dataset.
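
As a quick sanity check after generation, the sizes in the table above can be kept in a small lookup and compared against `len()` of the instantiated dataset. This is only a sketch: `WIDEBAND_SIZES` and `describe_variant` are hypothetical helpers, and the numbers are copied from the table rather than queried from TorchSig.

```python
from typing import Dict, Tuple

# (impaired, qa) -> (train examples, val examples), copied from the table above.
WIDEBAND_SIZES: Dict[Tuple[bool, bool], Tuple[int, int]] = {
    (False, False): (250_000, 25_000),
    (False, True): (250, 250),
    (True, False): (250_000, 25_000),
    (True, True): (250, 250),
}


def describe_variant(impaired: bool, qa: bool) -> str:
    """Summarize the expected dataset sizes for one flag combination."""
    train, val = WIDEBAND_SIZES[(impaired, qa)]
    kind = "impaired" if impaired else "clean"
    return f"{kind}: train={train}, val={val}"


summary = describe_variant(impaired=True, qa=True)
```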

In [None]:
# Generate TorchSigWideband DataModule
# Note: the qa datasets are intended for illustrative code and spot checks. 
# Do not expect significant model training results with these small data subsets.
root = "./datasets/wideband"
impaired = True
qa = True
fft_size = 512
num_classes = len(torchsig_signals.class_list)
batch_size = 1

transform = Compose([    
])

target_transform = Compose([
    DescToBBoxFamilyDict()
])

datamodule = WidebandDataModule(
    root=root,
    impaired=impaired,
    qa=qa,
    fft_size=fft_size,
    num_classes=num_classes,
    transform=transform,
    target_transform=target_transform,
    batch_size=batch_size
)

In [None]:
datamodule.prepare_data()
datamodule.setup("fit")

wideband_train = datamodule.train

# Retrieve a sample and print out information
idx = np.random.randint(len(wideband_train))
data, label = wideband_train[idx]
print("Dataset length: {}".format(len(wideband_train)))
print("Data shape: {}".format(data.shape))
print("Label: {}".format(label))

## Prepare YOLO trainer and Model
Next, the datasets are rewritten to disk in an Ultralytics YOLO-compatible layout. See [Ultralytics: Train Custom Data - Organize Directories](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#23-organize-directories) to learn more.
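
For orientation, Ultralytics detection labels are plain-text files with one row per object in the form `class x_center y_center width height`, where the box values are normalized to the image size. The sketch below shows that conversion for a pixel-coordinate box. `to_yolo_label` is a hypothetical helper written for illustration; it is not the function TorchSig uses when it writes the dataset.

```python
def to_yolo_label(
    class_id: int,
    x_min: float,
    y_min: float,
    x_max: float,
    y_max: float,
    img_w: int,
    img_h: int,
) -> str:
    """Convert a pixel-coordinate box to one normalized YOLO label row."""
    if not (0 <= x_min < x_max <= img_w and 0 <= y_min < y_max <= img_h):
        raise ValueError("box must lie inside the image and have positive area")
    x_center = (x_min + x_max) / 2 / img_w
    y_center = (y_min + y_max) / 2 / img_h
    width = (x_max - x_min) / img_w
    height = (y_max - y_min) / img_h
    return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"


# A box covering columns 128-384 and rows 0-256 of a 512x512 spectrogram image.
line = to_yolo_label(0, 128, 0, 384, 256, 512, 512)
```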

Additionally, create a YAML file for the dataset configuration. See [Ultralytics: Train Custom Data - Create dataset.yaml](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#21-create-datasetyaml).

Download the desired YOLO model from [Ultralytics Models](https://docs.ultralytics.com/models/). This example uses YOLOv8, specifically `yolov8n.pt`.

---

### Explanation of the `overrides` Dictionary

The `overrides` dictionary customizes the Ultralytics YOLO trainer by supplying values that replace the default configuration. The dictionary is imported from `wbdata.yaml`, but you can also define it directly in the notebook. See [Ultralytics Train Settings](https://docs.ultralytics.com/modes/train/#train-settings) to learn more.

Example:

```python
overrides = {'model': 'yolov8n.pt', 'epochs': 100, 'data': 'wbdata.yaml', 'device': 0, 'imgsz': 512, 'single_cls': True}
```
A `.yaml` file is required for training. See `06_yolo.yaml` in the examples directory; it contains the path to your TorchSig data.
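
Conceptually, overriding means that any key you supply replaces the corresponding default, while keys you omit keep their default values. The sketch below shows that merge with plain dictionaries. `ILLUSTRATIVE_DEFAULTS` holds made-up placeholder values, not Ultralytics' actual defaults, and `apply_overrides` is a hypothetical helper rather than part of the Ultralytics API.

```python
from typing import Any, Dict

# Placeholder values for illustration only; NOT the real Ultralytics defaults.
ILLUSTRATIVE_DEFAULTS: Dict[str, Any] = {
    "epochs": 1,
    "imgsz": 640,
    "device": "cpu",
    "single_cls": False,
}


def apply_overrides(defaults: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new settings dict in which overrides win over defaults."""
    merged = dict(defaults)  # copy so the defaults are left untouched
    merged.update(overrides)
    return merged


overrides = {
    "model": "yolov8n.pt",
    "epochs": 100,
    "data": "wbdata.yaml",
    "device": 0,
    "imgsz": 512,
    "single_cls": True,
}
settings = apply_overrides(ILLUSTRATIVE_DEFAULTS, overrides)
```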


### Dataset Location Warning

A datasets directory must exist at `/path/to/torchsig/datasets`.

This example assumes that you have generated the `train` and `val` LMDB wideband datasets at `./datasets/wideband/`.

You can also specify an absolute path to your dataset in `06_yolo.yaml`.

In [None]:
# Define dataset variables for the YAML file
config_name = "06_yolo.yaml"
class_list = ["ask", "fsk", "ofdm", "pam", "psk", "qam"]
classes = {index: name for index, name in enumerate(class_list)}  # class index -> family name
num_classes = len(classes)
yolo_root = "./wideband/"  # train/val images (relative to './datasets')

# define overrides
# Note: You can change use of GPU(s) or CPU by overriding the device
# GPU: device=0 or device=0,1
# CPU: device="cpu"
overrides = dict(
    model = "yolov8n.pt",
    project = "yolo",
    name = "06_example",
    epochs = 10,
    imgsz = 512,
    data = config_name,
    device = 0 if torch.cuda.is_available() else "cpu",
    single_cls = True,
    batch = 32,
    workers = 8
)

# create yaml file for trainer
yolo_config = dict(
    overrides = overrides,
    train = yolo_root,
    val = yolo_root,
    nc = num_classes,
    names = classes
)

with open(config_name, 'w+') as file:
    yaml.dump(yolo_config, file, default_flow_style=False)

print(f"Creating experiment -> {overrides['name']}")
    

In [None]:
trainer = Yolo_Trainer(overrides=overrides)

## Train
Train YOLO. See [Ultralytics Train](https://docs.ultralytics.com/modes/train/#train-settings) for training hyperparameter options.

---

In [None]:
trainer.train()

## Evaluation
Check the model's performance after training. From here, you can use the trained model to run inference on prepared data (NumPy image arrays of spectrograms).

The examples below are loaded from TorchSig.

`model_path` is the path to `best.pt` from your training session. The path is printed at the end of training.

---

## Generate and Instantiate Wideband Test Dataset
After generating the Wideband dataset (see `03_example_widebandsig_dataset.ipynb`), we can instantiate it with the needed transforms. Change `root` to the path of your test dataset.
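
The transform in this section uses `nperseg=512`, `noverlap=0`, and `nfft=512`. To see how such settings relate to the image size passed to YOLO, the sketch below estimates the spectrogram's shape with standard STFT framing arithmetic. It assumes no padding or boundary extension, complex IQ input (a two-sided spectrum with `nfft` bins), and a capture length of 512 * 512 samples. That length is an assumption for illustration, so check `data.shape` on your own dataset. `spectrogram_shape` is a hypothetical helper, not a TorchSig function.

```python
from typing import Tuple


def spectrogram_shape(num_samples: int, nperseg: int, noverlap: int, nfft: int) -> Tuple[int, int]:
    """Estimate (frequency bins, time segments) for an unpadded STFT of complex data."""
    step = nperseg - noverlap
    if step <= 0:
        raise ValueError("noverlap must be smaller than nperseg")
    if num_samples < nperseg:
        raise ValueError("need at least one full segment")
    num_segments = (num_samples - noverlap) // step
    return nfft, num_segments


# Assumed capture length of 512 * 512 IQ samples (not confirmed by this notebook).
shape = spectrogram_shape(num_samples=512 * 512, nperseg=512, noverlap=0, nfft=512)
```

Under these assumptions the result is a square 512 by 512 array, which matches the `imgsz=512` setting used for training and prediction.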

---

In [None]:
test_path = './datasets/wideband_test'  # should differ from your training dataset

transform = Compose([
    Spectrogram(nperseg=512, noverlap=0, nfft=512, mode='psd'),
    Normalize(norm=np.inf, flatten=True),
    SpectrogramImage(), 
    ])

test_data = WidebandDataModule(
    root=test_path,
    impaired=impaired,
    qa=qa,
    fft_size=fft_size,
    num_classes=num_classes,
    transform=transform,
    target_transform=None,
    batch_size=batch_size
)

test_data.prepare_data()
test_data.setup("fit")

wideband_test = test_data.train

# Retrieve a sample and print out information
idx = np.random.randint(len(wideband_test))
data, label = wideband_test[idx]
print("Dataset length: {}".format(len(wideband_test)))
print("Data shape: {}".format(data.shape))

samples = []
labels = []
for i in range(10):
    idx = np.random.randint(len(wideband_test))
    sample, label = wideband_test[idx]
    lb = [item['class_name'] for item in label]
    samples.append(sample)
    labels.append(lb)

### Load model 
The model path is printed after training. Use the `best.pt` weights.
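
If the printed path has scrolled out of view, the checkpoint can usually be located on disk. The sketch below assumes the usual Ultralytics layout of `<project>/<name>/weights/best.pt` and picks the most recently modified match, since run names may be given a numeric suffix when a name already exists. `find_latest_best` is a hypothetical helper, and the demo builds a throwaway directory instead of touching real training output.

```python
import os
import tempfile
from pathlib import Path
from typing import Optional


def find_latest_best(project_dir: str) -> Optional[Path]:
    """Return the most recently modified <run>/weights/best.pt under project_dir."""
    candidates = list(Path(project_dir).glob("*/weights/best.pt"))
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


# Demo on a temporary directory that mimics two training runs.
with tempfile.TemporaryDirectory() as tmp:
    for age, run in enumerate(["06_example", "06_example2"]):
        weights_dir = Path(tmp) / run / "weights"
        weights_dir.mkdir(parents=True)
        checkpoint = weights_dir / "best.pt"
        checkpoint.write_bytes(b"")  # empty placeholder file
        os.utime(checkpoint, (1000 + age, 1000 + age))  # make the second run newer
    latest_run = find_latest_best(tmp).parent.parent.name
```

With the settings in this notebook, the call would be `find_latest_best("yolo")`, and the returned path could be passed to `YOLO(...)` in place of `trainer.best`.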

In [None]:
model = YOLO(trainer.best)

In [None]:
# Inference results are saved to the path printed after predict.
results = model.predict(samples, save=True, imgsz=512, conf=0.5)

In [None]:
%matplotlib inline

In [None]:
# Plot prediction results
# Note: do not expect significant detections with default parameters. See previous note on qa datasets.
rows = 3
cols = 3
fig = plt.figure(figsize=(15, 15)) 
results_dir = results[0].save_dir

for y, result in enumerate(results[:9]):
    imgpath = os.path.join(results_dir, "image" + str(y) + ".jpg")
    fig.add_subplot(rows, cols, y + 1)
    img = cv2.imread(imgpath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
    plt.imshow(img)
    plt.title(str(labels[y]), fontsize='small', loc='left')