# Example 07 - Sig53 with YOLOv8 Classifier
This notebook showcases using the Sig53 dataset to train a YOLOv8 classification model.

---

## Import Libraries
We will import all the usual libraries, in addition to Ultralytics. You can install Ultralytics with:
```bash
pip install ultralytics
```

In [None]:
# Packages for Training
from torchsig.utils.yolo_classify import *
from torchsig.utils.classify_transforms import real_imag_vstacked_cwt_image, complex_iq_to_heatmap
import yaml
from PIL import Image

In [None]:
# Packages for testing/inference
import numpy as np  # used below for np.inf and np.random
from torchsig.datasets.modulations import ModulationsDataset
from torchsig.transforms.target_transforms import DescToFamilyName
from torchsig.transforms.transforms import Spectrogram, SpectrogramImage, Normalize, Compose, Identity
from ultralytics import YOLO
from PIL import Image

## Prepare YOLO Classification Trainer and Model
Datasets are generated on the fly in a format compatible with Ultralytics YOLO. See [Ultralytics: Train Custom Data - Organize Directories](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#23-organize-directories) to learn more.

Additionally, we create a YAML file for the dataset configuration. See `classify.yaml` in the TorchSig examples.
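
The cells later in this notebook read several top-level keys from the loaded configuration: `overrides`, `family`, `names`, `level`, `nc`, and `include_snr`. As a rough sketch, a small helper like the hypothetical `missing_config_keys` below can catch a missing key before training starts. The example values are illustrative placeholders only, not the contents of the shipped `classify.yaml`.

```python
# Keys that later cells in this notebook read from the loaded config.
REQUIRED_KEYS = ("overrides", "family", "names", "level", "nc", "include_snr")


def missing_config_keys(config, required=REQUIRED_KEYS):
    """Return the required top-level keys that are absent from `config`."""
    return [key for key in required if key not in config]


# Illustrative stand-in for the dict returned by yaml.safe_load();
# every value here is a placeholder, not taken from classify.yaml.
example_config = {
    "overrides": {"model": "yolov8n-cls.pt", "epochs": 100, "imgsz": 64},
    "family": False,
    "names": {0: "bpsk", 1: "qpsk"},
    "level": 0,
    "nc": 2,
    "include_snr": False,
}

print(missing_config_keys(example_config))  # []
print(missing_config_keys({"overrides": {}, "nc": 2}))
```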

Download the desired YOLO model from [Ultralytics Models](https://docs.ultralytics.com/models/). We will use YOLOv8, specifically `yolov8n-cls.pt`.

---

In [None]:
config_path = 'classify.yaml'
with open(config_path, 'r') as file:
    config = yaml.safe_load(file)

overrides = config['overrides']

### Explanation of the `overrides` Dictionary

The `overrides` dictionary customizes the Ultralytics YOLO trainer by replacing default settings with the values you specify. Here it is loaded from `classify.yaml`, but you can also define or modify it directly in the notebook.

Example:

```python
overrides = {'model': 'yolov8n-cls.pt', 'epochs': 100, 'data': 'classify.yaml', 'device': 0, 'imgsz': 64}
```
A YAML file is required for training. See `classify.yaml` in the examples directory; it contains the path to your TorchSig data.
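
To adjust a setting in the notebook without editing the file, one option is to merge the loaded values with a few notebook-level changes. This is a minimal sketch: `loaded_overrides` stands in for `config['overrides']`, and the changed values are arbitrary examples.

```python
# Stand-in for config['overrides'] as loaded from classify.yaml (values illustrative).
loaded_overrides = {
    "model": "yolov8n-cls.pt",
    "epochs": 100,
    "data": "classify.yaml",
    "device": 0,
    "imgsz": 64,
}

# Notebook-level changes; later keys win in a dict merge.
notebook_changes = {"epochs": 10, "device": "cpu"}

overrides = {**loaded_overrides, **notebook_changes}
print(overrides["epochs"], overrides["device"], overrides["imgsz"])  # 10 cpu 64
```

The merge builds a new dictionary, so the values loaded from the file stay unchanged and can be reused for another run.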

### Explanation of the `image_transform` Function
`YoloClassifyTrainer` accepts any transform that takes complex I/Q samples and returns an image for training. Example transforms can be found in `torchsig.utils.classify_transforms`. If no transform is passed, it defaults to spectrogram images. Update `overrides` so that `imgsz` matches the size of the transform's output.
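
As an illustration of the kind of callable described above, here is a hedged sketch of a custom transform written with NumPy only. The name `iq_to_spectrogram_image`, its signature, and the choice of a `uint8` grayscale output are assumptions made for this example; check the functions in `torchsig.utils.classify_transforms` for the exact input and output conventions that `YoloClassifyTrainer` expects.

```python
import numpy as np


def iq_to_spectrogram_image(iq, imgsz=64):
    """Map complex I/Q samples to an (imgsz, imgsz) uint8 grayscale image.

    Hypothetical example: splits the signal into `imgsz` non-overlapping
    segments of `imgsz` samples and takes the FFT of each segment.
    """
    iq = np.asarray(iq, dtype=np.complex64)[: imgsz * imgsz]
    if iq.size < imgsz * imgsz:
        raise ValueError(f"need at least {imgsz * imgsz} samples, got {iq.size}")

    segments = iq.reshape(imgsz, imgsz)  # each row is one time segment
    spectrum = np.fft.fftshift(np.fft.fft(segments, axis=1), axes=1)
    power_db = 10.0 * np.log10(np.abs(spectrum) ** 2 + 1e-12)

    # Scale to 0..255 so the result can be treated as an image.
    power_db -= power_db.min()
    peak = power_db.max()
    if peak > 0:
        power_db /= peak
    image = (255.0 * power_db).astype(np.uint8)
    return image.T  # frequency on the vertical axis, time on the horizontal


# Usage with synthetic data: a complex tone plus a little noise.
rng = np.random.default_rng(0)
n = 64 * 64
tone = np.exp(2j * np.pi * 0.1 * np.arange(n))
noise = 0.1 * (rng.standard_normal(n) + 1j * rng.standard_normal(n))
image = iq_to_spectrogram_image(tone + noise, imgsz=64)
print(image.shape, image.dtype)
```

With a transform like this one, `imgsz` in `overrides` would be set to 64 to match the output size.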

### Build YoloClassifyTrainer
This will instantiate the YOLO classification trainer with overrides specified above.

In [None]:
trainer = YoloClassifyTrainer(overrides=overrides, image_transform=None)

### Begin Training

In [None]:
trainer.train()

### Instantiate Test Dataset

This cell uses TorchSig's `ModulationsDataset` to generate a narrowband classification dataset.

In [None]:
# Determine whether to map descriptions to family names
if config['family']:
    target_transform = Compose([DescToFamilyName()])
else:
    target_transform = None

transform = Compose([
    Spectrogram(nperseg=overrides['imgsz'], noverlap=0, nfft=overrides['imgsz'], mode='psd'),
    Normalize(norm=np.inf, flatten=True),
    SpectrogramImage(), 
    ])

class_list = list(config['names'].values())

dataset = ModulationsDataset(
    classes=class_list,
    use_class_idx=False,
    level=config['level'],
    num_iq_samples=overrides['imgsz']**2,
    num_samples=int(config['nc'] * 10),
    include_snr=config['include_snr'],
    transform=transform,
    target_transform=target_transform
)

# Retrieve a sample and print out information
idx = np.random.randint(len(dataset))
data, label = dataset[idx]
print("Dataset length: {}".format(len(dataset)))
print("Data shape: {}".format(data.shape))

samples = []
labels = []
for i in range(10):
    idx = np.random.randint(len(dataset))
    sample, label = dataset[idx]
    samples.append(sample)
    labels.append(label)
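
The choice `num_iq_samples=overrides['imgsz']**2` pairs with the `Spectrogram` settings above. Assuming the transform behaves like a standard short-time Fourier transform without padding, non-overlapping segments of `imgsz` samples yield `imgsz` time slices, and an `imgsz`-point FFT of complex input yields `imgsz` frequency bins, giving a square image. The helper below is a hypothetical sketch of that arithmetic, not TorchSig code.

```python
def spectrogram_shape(num_iq_samples, nperseg, noverlap, nfft):
    """Return (frequency_bins, time_slices) for a two-sided STFT without padding."""
    step = nperseg - noverlap
    time_slices = (num_iq_samples - noverlap) // step
    return nfft, time_slices


imgsz = 64
print(spectrogram_shape(imgsz ** 2, nperseg=imgsz, noverlap=0, nfft=imgsz))  # (64, 64)
```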

### Predictions / Inference
The following cells show how to load the `best.pt` weights saved during training and use them for prediction.

In [None]:
model_path = 'YOUR_PROJECT_NAME/YOUR_CLASSIFY_EXPERIMENT/weights/best.pt'  # replace with your path to 'best.pt'
model = YOLO(model_path)  # the model will remember the configuration from training

In [None]:
results = model.predict(samples)

In [None]:
# Process results list
for y, result in enumerate(results):
    probs = result.probs  # Probs object for classification outputs
    # Map the top-5 class indices to class names
    top5_names = ",".join(result.names[i] for i in probs.top5)
    print(f'Actual Labels -> {labels[y]}')
    print(f'Top 1 Prediction ->  {result.names[probs.top1]}, {probs.top1conf}')
    print(f'Top 5 Prediction ->  {top5_names}, {list(probs.top5conf.cpu().numpy())}')

    img = Image.fromarray(result.orig_img)
    img.show()  # display to screen
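
To summarize the loop above as a single number, one could compute top-k accuracy over the collected predictions. The sketch below uses plain Python lists in place of Ultralytics result objects; `top_k_accuracy` is a hypothetical helper, and the label strings are made-up examples. It also assumes the true labels use the same naming as the model's class names.

```python
def top_k_accuracy(true_labels, ranked_predictions, k=1):
    """Fraction of samples whose true label is among the first k predictions."""
    if len(true_labels) != len(ranked_predictions):
        raise ValueError("true_labels and ranked_predictions must have equal length")
    if not true_labels:
        return 0.0
    hits = sum(
        label in ranked[:k]
        for label, ranked in zip(true_labels, ranked_predictions)
    )
    return hits / len(true_labels)


# Made-up example: each inner list is ordered from most to least likely.
true_labels = ["bpsk", "qpsk", "ofdm-64"]
ranked_predictions = [
    ["bpsk", "qpsk", "8psk"],
    ["8psk", "qpsk", "bpsk"],
    ["ofdm-72", "ofdm-128", "bpsk"],
]
print(top_k_accuracy(true_labels, ranked_predictions, k=1))
print(top_k_accuracy(true_labels, ranked_predictions, k=2))
```

In the notebook, each ranked list could be built from a result with `[result.names[i] for i in result.probs.top5]`.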
