# Example 07 - Sig53 with YOLOv8 Classifier
This notebook showcases using the Sig53 dataset to train a YOLOv8 classification model.

---

## Import Libraries
We will import all the usual libraries, in addition to Ultralytics. You can install Ultralytics with:
```bash
pip install ultralytics
```

In [1]:
# Packages for Training
from torchsig.utils.yolo_classify import *
from torchsig.utils.classify_transforms import real_imag_vstacked_cwt_image, complex_iq_to_heatmap
import yaml
from PIL import Image

In [2]:
# Packages for testing/inference
import numpy as np  # used below for random sampling and normalization
from torchsig.datasets.modulations import ModulationsDataset
from torchsig.transforms.target_transforms import DescToFamilyName
from torchsig.transforms.transforms import Spectrogram, SpectrogramImage, Normalize, Compose, Identity
from ultralytics import YOLO
from PIL import Image

## Prepare YOLO Classification Trainer and Model
Datasets are generated on the fly in a way that is Ultralytics YOLO compatible. See [Ultralytics: Train Custom Data - Organize Directories](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#23-organize-directories) to learn more. 

Additionally, we create a YAML file for dataset configuration. See `classify.yaml` in the Torchsig examples.

Download the desired YOLO model from [Ultralytics Models](https://docs.ultralytics.com/models/). We will use YOLOv8, specifically `yolov8n-cls.pt`.

---

In [3]:
config_path = 'classify.yaml'
with open(config_path, 'r') as file:
    config = yaml.safe_load(file)

overrides = config['overrides']

### Explanation of the `overrides` Dictionary

The `overrides` dictionary customizes the Ultralytics YOLO trainer: each key you set replaces the corresponding default setting. Here it is loaded from `classify.yaml`, but you can also define or modify it directly in the notebook.

Example:

```python
overrides = {'model': 'yolov8n-cls.pt', 'epochs': 100, 'data': 'classify.yaml', 'device': 0, 'imgsz': 64}
```
A `.yaml` file is required for training. See `classify.yaml` in the examples directory; it contains the path to your Torchsig data.
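
To adjust only a few settings while keeping the rest of the file's values, you can merge two dictionaries. The sketch below uses a stand-in dictionary in place of `config['overrides']`; its values and the `notebook_overrides` name are illustrative assumptions, not the contents of the shipped `classify.yaml`.

```python
# Stand-in for config['overrides'] as loaded from classify.yaml (values are illustrative)
file_overrides = {
    'model': 'yolov8n-cls.pt',
    'epochs': 100,
    'data': 'classify.yaml',
    'device': 0,
    'imgsz': 64,
}

# Hypothetical notebook-side tweaks: a short run on the CPU
notebook_overrides = {'epochs': 5, 'device': 'cpu'}

# Later keys win, so the notebook values replace the file values
overrides = {**file_overrides, **notebook_overrides}
print(overrides)
```

Keys that are not mentioned in `notebook_overrides`, such as `imgsz`, keep the value from the file.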

### Explanation of `image_transform` function
`YoloClassifyTrainer` accepts any transform that takes complex I/Q samples and returns an image for training. Example transforms can be found in `torchsig.utils.classify_transforms`. If no transform is passed, it defaults to spectrogram images. Make sure the `imgsz` value in `overrides` matches the size of the images your transform produces.
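
As a rough illustration of what such a transform could look like, the sketch below turns a complex I/Q vector into a square grayscale spectrogram image using only NumPy. The function name `iq_to_spectrogram_image` is hypothetical, and the output layout (height × width × 3, `uint8`) is an assumption; check the transforms in `torchsig.utils.classify_transforms` for the format the trainer actually expects.

```python
import numpy as np

def iq_to_spectrogram_image(iq: np.ndarray, size: int = 64) -> np.ndarray:
    """Hypothetical transform: complex I/Q -> (size, size, 3) uint8 image.

    Assumes `iq` holds at least size * size complex samples.
    """
    # Split the signal into `size` non-overlapping frames of `size` samples
    frames = iq[: size * size].reshape(size, size)
    # FFT of each frame, with the zero-frequency bin shifted to the center
    spectrum = np.fft.fftshift(np.fft.fft(frames, axis=1), axes=1)
    # Power in dB; the small constant avoids log(0)
    power_db = 10.0 * np.log10(np.abs(spectrum) ** 2 + 1e-12)
    # Scale to the 0..255 range
    lo, hi = power_db.min(), power_db.max()
    scaled = (power_db - lo) / max(hi - lo, 1e-12)
    gray = (255 * scaled).astype(np.uint8)
    # Repeat the single channel three times to form an RGB-shaped image
    return np.stack([gray, gray, gray], axis=-1)

# Random complex noise as a stand-in for a real I/Q capture
rng = np.random.default_rng(0)
iq = rng.standard_normal(64 * 64) + 1j * rng.standard_normal(64 * 64)
image = iq_to_spectrogram_image(iq, size=64)
print(image.shape, image.dtype)
```

With a transform like this, `overrides['imgsz']` would be set to `64` so that it matches `image.shape[0]`.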

### Build YoloClassifyTrainer
This will instantiate the YOLO classification trainer with overrides specified above.

In [4]:
trainer = YoloClassifyTrainer(overrides=overrides, image_transform=None)

Ultralytics 8.3.3 🚀 Python-3.10.12 torch-2.4.1+cu121 CUDA:0 (NVIDIA GeForce RTX 4090 Laptop GPU, 15981MiB)


[34m[1mengine/trainer: [0mtask=classify, mode=train, model=yolov8n-cls.pt, data=classify.yaml, epochs=1, time=None, patience=100, batch=32, imgsz=64, save=True, save_period=-1, cache=False, device=0, workers=32, project=YOUR_PROJECT_NAME, name=YOUR_CLASSIFY_EXPERIMENT5, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=False, save_conf=False, save_crop=False, show_labels=True, show_conf=True, show_boxes=True, line_width=None, format=torchscript, ker

### Begin Training

In [5]:
trainer.train()

[34m[1mTensorBoard: [0mStart with 'tensorboard --logdir YOUR_PROJECT_NAME/YOUR_CLASSIFY_EXPERIMENT5', view at http://localhost:6006/
Overriding model.yaml nc=1000 with nc=53

                   from  n    params  module                                       arguments                     
  0                  -1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]                 
  1                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]                
  2                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]             
  3                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]                
  4                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]             
  5                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]               
  6                  -1 

2025/01/14 13:34:32 INFO mlflow.tracking.fluent: Autologging successfully enabled for tensorflow.
2025/01/14 13:34:32 INFO mlflow.tracking.fluent: Autologging successfully enabled for keras.


[34m[1mMLflow: [0mlogging run_id(f5a8b67d56a549c19aef60502a6343ea) to runs/mlflow
[34m[1mMLflow: [0mview at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri runs/mlflow'
[34m[1mMLflow: [0mdisable with 'yolo settings mlflow=False'
CWSpikeDataset._generate_samples called
SPIKEDataset
CWSpikeDataset._generate_samples called
SPIKEDataset
CWSpikeDataset._generate_samples called
SPIKEDataset
CWSpikeDataset._generate_samples called
SPIKEDataset
CWSpikeDataset._generate_samples called
SPIKEDataset
CWSpikeDataset._generate_samples called
SPIKEDataset
CWSpikeDataset._generate_samples called
SPIKEDataset
CWSpikeDataset._generate_samples called
SPIKEDataset
CWSpikeDataset._generate_samples called
SPIKEDataset
CWSpikeDataset._generate_samples called
SPIKEDataset
[34m[1mTensorBoard: [0mmodel graph visualization added ✅
Image sizes 64 train, 64 val
Using 32 dataloader workers
Logging results to YOUR_PROJECT_NAME/YOUR_CLASSIFY_EXPERIMENT5
Starting training for 1 epochs

  0%|          | 0/20 [00:00<?, ?it/s]


KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/mtwente/.local/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/mtwente/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/mtwente/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/mtwente/.local/lib/python3.10/site-packages/torchsig/utils/yolo_classify.py", line 115, in __getitem__
    j = self.class_to_idx_dict[label_name]  # Get class index from the label name
KeyError: 'fm'
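
The traceback above shows the lookup `self.class_to_idx_dict[label_name]` failing for the label `'fm'`, which means the dataset produced a label that is not a key in the trainer's class map. A minimal pre-flight check is sketched below. It assumes you can collect the label names your dataset emits and the class names from your configuration; the helper name and both stand-in values are hypothetical.

```python
def find_unmapped_labels(label_names, class_to_idx):
    """Return, in sorted order, the labels that have no entry in class_to_idx."""
    return sorted(set(label_names) - set(class_to_idx))

# Hypothetical stand-ins: replace with the labels from your dataset
# and the class-name mapping from your own configuration
class_to_idx = {'bpsk': 0, 'qpsk': 1, '8psk': 2}
label_names = ['bpsk', 'fm', 'qpsk', 'fm']

unmapped = find_unmapped_labels(label_names, class_to_idx)
print(unmapped)  # prints ['fm']
```

An empty list means every label has a class index; any name that is returned would trigger the same `KeyError` during training.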


### Instantiate Test Dataset

Uses Torchsig's `ModulationsDataset` to generate a narrowband classification dataset. 

In [None]:
# Determine whether to map class descriptions to family names.
# `Compose` (imported above) wraps the target transform.
if config['family']:
    target_transform = Compose([DescToFamilyName()])
else:
    target_transform = None

transform = Compose([
    Spectrogram(nperseg=overrides['imgsz'], noverlap=0, nfft=overrides['imgsz'], mode='psd'),
    Normalize(norm=np.inf, flatten=True),
    SpectrogramImage(), 
    ])

# Class names are the values of the `names` mapping in classify.yaml
class_list = list(config['names'].values())

dataset = ModulationsDataset(
    classes=class_list,
    use_class_idx=False,
    level=config['level'],
    num_iq_samples=overrides['imgsz']**2,
    num_samples=int(config['nc'] * 10),
    include_snr=config['include_snr'],
    transform=transform,
    target_transform=target_transform
)

# Retrieve a random sample and print out information
idx = np.random.randint(len(dataset))
data, label = dataset[idx]
print(f"Dataset length: {len(dataset)}")
print(f"Data shape: {data.shape}")

samples = []
labels = []
for i in range(10):
    idx = np.random.randint(len(dataset))
    sample, label = dataset[idx]
    samples.append(sample)
    labels.append(label)
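
The cell above sets `num_iq_samples` to `imgsz**2` and uses `nperseg=nfft=imgsz` with `noverlap=0`, which is what makes the spectrogram square. The arithmetic can be sketched as follows, assuming non-overlapping, unpadded framing as in `scipy.signal.spectrogram` and a two-sided FFT for complex input. The helper name is hypothetical, and Torchsig's `Spectrogram` transform may frame the signal differently.

```python
def spectrogram_shape(num_iq_samples: int, nperseg: int, noverlap: int, nfft: int):
    """Return (frequency_bins, time_frames) for an unpadded spectrogram of complex input."""
    hop = nperseg - noverlap  # samples advanced between consecutive frames
    time_frames = (num_iq_samples - noverlap) // hop
    frequency_bins = nfft  # two-sided spectrum keeps all nfft bins
    return frequency_bins, time_frames

imgsz = 64
shape = spectrogram_shape(num_iq_samples=imgsz ** 2, nperseg=imgsz, noverlap=0, nfft=imgsz)
print(shape)  # prints (64, 64)
```

Under these assumptions, 4096 samples split into 64 frames of 64 samples each give a 64 × 64 image, which matches `imgsz=64`.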

### Predictions / Inference
The following cells show how to load the `best.pt` weights saved during training and use them for prediction.

In [None]:
model_path = 'YOUR_PROJECT_NAME/YOUR_CLASSIFY_EXPERIMENT/weights/best.pt'  # replace with your path to 'best.pt'
model = YOLO(model_path)  # the model will remember the configuration from training

In [None]:
results = model.predict(samples)

In [None]:
# Process results list
for y, result in enumerate(results):
    probs = result.probs  # Probs object for classification outputs
    print(f'Actual Labels -> {labels[y]}')
    print(f'Top 1 Prediction ->  {result.names[probs.top1]}, {float(probs.top1conf):.3f}')
    # Look up the class names for the five highest-scoring indices
    top5_names = [result.names[i] for i in probs.top5]
    print(f'Top 5 Prediction ->  {", ".join(top5_names)}, {probs.top5conf.cpu().numpy().tolist()}')

    img = Image.fromarray(result.orig_img)
    img.show()  # display to screen
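
To summarize the predictions in a single number, you could compare the actual labels with the top-1 predicted names. The sketch below is a plain-Python helper; the function name and the sample lists are hypothetical stand-ins for `labels` and for the names read from `results`.

```python
def top1_accuracy(actual_labels, predicted_labels):
    """Fraction of positions where the predicted label equals the actual label."""
    if len(actual_labels) != len(predicted_labels):
        raise ValueError("actual_labels and predicted_labels must have the same length")
    if not actual_labels:
        return 0.0
    correct = sum(a == p for a, p in zip(actual_labels, predicted_labels))
    return correct / len(actual_labels)

# Hypothetical stand-ins for the notebook's labels and top-1 predictions
actual = ['bpsk', 'qpsk', 'fm', 'bpsk']
predicted = ['bpsk', 'fm', 'fm', 'bpsk']

accuracy = top1_accuracy(actual, predicted)
print(f'Top-1 accuracy: {accuracy:.2f}')  # prints Top-1 accuracy: 0.75
```

In the notebook, the predicted names could be collected with `[r.names[r.probs.top1] for r in results]`. This assumes each entry in `labels` is a plain class name; if your configuration also returns SNR with the label, extract the name first.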
