In [None]:
import gc
import glob
import os
import random
import zipfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from tqdm.auto import tqdm

from super_gradients.training import Trainer
from super_gradients.training import dataloaders
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import (
    coco_detection_yolo_format_train,
    coco_detection_yolo_format_val,
)
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import (
    DetectionMetrics_050,
    DetectionMetrics_050_095,
)
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback

In [None]:
# Dataset root and the per-split image/label folders (relative to ROOT_DIR).
ROOT_DIR = 'datasets'
train_imgs_dir, train_labels_dir = 'train/images', 'train/labels'
val_imgs_dir, val_labels_dir = 'valid/images', 'valid/labels'
test_imgs_dir, test_labels_dir = 'test/images', 'test/labels'

# Detection class names, grouped by name prefix for readability.
# Order matters: a name's index is presumably the integer class id used in the
# YOLO label files, so never reorder entries or insert in the middle.
classes = [
    # WS_* (41)
    'WS_XR_CrossShape', 'WS_XR_Tshape', 'WS_XR_Yshape', 'WS_XR_RTshape',
    'WS_XR_LTshape', 'WS_ThroughST', 'WS_Merge_R', 'WS_Merge_L',
    'WS_Roundabout', 'WS_CrossRailroad', 'WS_R_Curve', 'WS_L_Curve',
    'WS_RL_Curve', 'WS_LL_Curve', 'WS_2way', 'WS_Ascent_Road',
    'WS_Descent_Road', 'WS_Narrow_Road', 'WS_Vanish_RightRoad', 'WS_Vanish_LeftRoad',
    'WS_Pass_R', 'WS_Pass_Both', 'WS_MedianStrip_Start', 'WS_MedianStrip_End',
    'WS_Flag', 'WS_Slippery', 'WS_Riverside', 'WS_UnevenSurface',
    'WS_SpeedBump', 'WS_Rockslide', 'WS_Crosswalk', 'WS_Children',
    'WS_Bicycle', 'WS_RC', 'WS_Airplane', 'WS_Sidewind',
    'WS_Tunnel', 'WS_Bridge', 'WS_Wild_Animal', 'WS_Danger',
    'WS_CCS',
    # RS_E_* (11)
    'RS_E_TEMPPause', 'RS_E_Square', 'RS_E_Triangle', 'RS_E_TrafficCone',
    'RS_E_Drum', 'RS_E_PEFence', 'RS_E_SignalVehicle', 'RS_E_ForkCrane',
    'RS_E_PEFence2', 'RS_E_SCFence', 'RS_E_ETHeavy',
    # I_Py_* (6)
    'I_Py_PEDmall', 'I_Py_Crosswalk', 'I_Py_P_Older', 'I_Py_P_Children',
    'I_Py_P_DisablePerson', 'I_Py_BicycleCrossing',
    # I_C_* (18)
    'I_C_Road_Car', 'I_C_Road_Bicycle', 'I_C_Road_Bicycle_PED', 'I_C_Roundabout',
    'I_C_STR', 'I_C_STR_RT', 'I_C_RT', 'I_C_LT',
    'I_C_STR_LT', 'I_C_LT_Uturn', 'I_C_LT_RT', 'I_C_Uturn',
    'I_C_Bothpass', 'I_C_Rightpass', 'I_C_Leftpass', 'I_C_Diversion',
    'I_C_Bicycle_PED', 'I_C_Car_Bicycle',
    # I_S_* (9)
    'I_S_TrafficByDirection', 'I_S_Road_Bicycle', 'I_S_PL', 'I_S_CyclePL',
    'I_S_Oneway', 'I_S_LT_Caution', 'I_S_Road_Bus', 'I_S_Road_HOV',
    'I_S_Pass_first',
    # TC_* (6)
    'TC_H_3color', 'TC_V_3color', 'TC_H_4color', 'TC_V_4color',
    'TC_Y_Flasher', 'TC_Red_flasher',
    # RS_T_* (2)
    'RS_T_Slow', 'RS_T_Yield',
    # RS_C_* (24)
    'RS_C_PROH_Pass', 'RS_C_PROH_CarPass', 'RS_C_PROH_TruckPass', 'RS_C_PROH_OmnibusPass',
    'RS_C_PROH_TWMV', 'RS_C_PROH_WMV', 'RS_C_PROH_TAH', 'RS_C_PROH_Bicycle',
    'RS_C_PROH_Entry', 'RS_C_PROH_STR', 'RS_C_PROH_RT', 'RS_C_PROH_LT',
    'RS_C_PROH_Uturn', 'RS_C_PROH_Overtaking', 'RS_C_PROH_TEMPStop', 'RS_C_PROH_STOP',
    'RS_C_Limit_Weight', 'RS_C_Limit_Height', 'RS_C_Limit_Breadth', 'RS_C_DistanceWithCar',
    'RS_C_Limit_MaxSpeed', 'RS_C_Limit_MinSpeed', 'RS_C_PROH_Walk', 'RS_C_PROH_LoadHazard',
]

# Everything the dataloader cells need, gathered in one mapping.
dataset_params = dict(
    data_dir=ROOT_DIR,
    train_images_dir=train_imgs_dir,
    train_labels_dir=train_labels_dir,
    val_images_dir=val_imgs_dir,
    val_labels_dir=val_labels_dir,
    test_images_dir=test_imgs_dir,
    test_labels_dir=test_labels_dir,
    classes=classes,
)

# Training length and dataloader batching / worker count.
EPOCHS, BATCH_SIZE, WORKERS = 50, 30, 8

In [None]:
# Build the train/val dataloaders for the YOLO-format dataset described by
# `dataset_params`. Both splits must use the same network input size, so it is
# defined once here instead of being repeated per split.
INPUT_DIM = [1280, 1280]


def _split_dataset_params(split):
    """Return the `dataset_params` dict for one split of the dataset.

    Args:
        split: split prefix, e.g. 'train', 'val' or 'test'; selects the
            `<split>_images_dir` / `<split>_labels_dir` entries of the global
            `dataset_params`.

    Returns:
        A new dict (with its own copy of INPUT_DIM), so nothing is shared
        between the dataloaders built from it.
    """
    return {
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params[f'{split}_images_dir'],
        'labels_dir': dataset_params[f'{split}_labels_dir'],
        'classes': dataset_params['classes'],
        'input_dim': list(INPUT_DIM),
    }


def _dataloader_params():
    """Return a new batching/worker config dict, identical for every split."""
    return {
        'batch_size': BATCH_SIZE,
        'num_workers': WORKERS,
    }


train_data = coco_detection_yolo_format_train(
    dataset_params=_split_dataset_params('train'),
    dataloader_params=_dataloader_params(),
)

val_data = coco_detection_yolo_format_val(
    dataset_params=_split_dataset_params('val'),
    dataloader_params=_dataloader_params(),
)

In [None]:
# Display the transform pipeline attached to the training dataset, in order.
(
    train_data.dataset
    .transforms
)

In [None]:
# Remove MixUp from the training augmentation pipeline.
# A bare `transforms.pop(2)` is not idempotent: each re-run of this cell drops
# whichever transform has shifted into position 2. Filtering by class name
# makes re-running the cell a no-op once the transform is gone.
# NOTE(review): index 2 of the default train pipeline is presumed to be
# DetectionMixup; confirm against the transform list displayed by the cell above.
TRANSFORMS_TO_DROP = {'DetectionMixup'}

# Slice-assign so the dataset keeps the same list object, as the in-place `pop` did.
train_data.dataset.transforms[:] = [
    transform
    for transform in train_data.dataset.transforms
    if type(transform).__name__ not in TRANSFORMS_TO_DROP
]
train_data.dataset.transforms

In [None]:
# Plot training data with the transform pipeline applied
# (`plot_transformed_data=True`) to eyeball the augmentations.
train_data.dataset.plot(
    plot_transformed_data=True,
)

In [None]:
def _make_detection_metric(metric_cls, num_classes):
    """Build one validation mAP metric with the settings shared by all metrics.

    Both validation metrics must be configured identically; keeping the
    settings here means they cannot drift apart.

    Args:
        metric_cls: metric class to instantiate (DetectionMetrics_050 or
            DetectionMetrics_050_095 below).
        num_classes: number of detection classes.

    Returns:
        A new metric instance with its own PPYoloEPostPredictionCallback.
    """
    return metric_cls(
        score_thres=0.1,
        top_k_predictions=300,
        num_cls=num_classes,
        normalize_targets=True,
        post_prediction_callback=PPYoloEPostPredictionCallback(
            score_threshold=0.01,
            nms_top_k=1000,
            max_predictions=300,
            nms_threshold=0.7,
        ),
    )


_num_classes = len(dataset_params['classes'])

# Training recipe handed to `Trainer.train`.
# NOTE(review): the loss and metric objects below are instances, so every
# training run that reuses this same dict also reuses those objects.
train_params = {
    "silent_mode": False,
    "average_best_models": True,
    # Learning-rate schedule: warm-up phase, then cosine decay.
    "warmup_mode": "linear_epoch_step",
    "warmup_initial_lr": 1e-6,
    "lr_warmup_epochs": 3,
    "initial_lr": 5e-4,
    "lr_mode": "cosine",
    "cosine_final_lr_ratio": 0.1,
    "optimizer": "Adam",
    "optimizer_params": {"weight_decay": 0.0001},
    "zero_weight_decay_on_bias_and_bn": True,
    "ema": True,
    "ema_params": {"decay": 0.9, "decay_type": "threshold"},
    "max_epochs": EPOCHS,
    "mixed_precision": True,
    "loss": PPYoloELoss(
        use_static_assigner=False,
        num_classes=_num_classes,
        reg_max=16,
    ),
    "valid_metrics_list": [
        _make_detection_metric(DetectionMetrics_050, _num_classes),
        _make_detection_metric(DetectionMetrics_050_095, _num_classes),
    ],
    # Presumably must match a metric name reported by valid_metrics_list
    # (expected from DetectionMetrics_050_095); confirm if renaming metrics.
    "metric_to_watch": 'mAP@0.50:0.95',
}

In [None]:
# Fine-tune each YOLO-NAS variant in turn, starting from COCO-pretrained
# weights, with the shared `train_params` and dataloaders.
models_to_train = [
    'yolo_nas_s',
    'yolo_nas_m',
    'yolo_nas_l',
]

# Root directory for checkpoints; each run is named after its model.
CHECKPOINT_DIR = 'checkpoints'

for model_to_train in models_to_train:
    # Release the previous run's trainer/model before building the next pair.
    # A finished trainer may be kept alive by reference cycles (and with it the
    # old network's weights, possibly on the GPU) while the next, larger model is
    # allocated; collect it explicitly and return cached CUDA blocks.
    # The last run's `trainer` and `model` remain bound after the loop.
    trainer = model = None
    gc.collect()
    torch.cuda.empty_cache()  # no-op when CUDA was never initialised

    trainer = Trainer(
        experiment_name=model_to_train,
        ckpt_root_dir=CHECKPOINT_DIR,
    )

    model = models.get(
        model_to_train,
        num_classes=len(dataset_params['classes']),
        pretrained_weights="coco",
    )

    trainer.train(
        model=model,
        training_params=train_params,
        train_loader=train_data,
        valid_loader=val_data,
    )