In [None]:
import torch
from super_gradients.training import Trainer
from super_gradients.training import dataloaders
from super_gradients.training.dataloaders.dataloaders import (
    coco_detection_yolo_format_train,
    coco_detection_yolo_format_val,
)
from super_gradients.training import models
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import (
    DetectionMetrics_050,
    DetectionMetrics_050_095
)
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
from tqdm.auto import tqdm

import os
import requests
import zipfile
import cv2
import matplotlib.pyplot as plt
import glob
import numpy as np
import random


In [None]:
# unzip the data file
def unzip(zip_file=None, dst=None):
    """Extract every member of the archive ``zip_file`` into ``dst``.

    This is a best-effort helper: a failure is reported on stdout rather than
    raised, so the caller keeps running.

    Args:
        zip_file: Path (or file object) of the zip archive to extract.
        dst: Target directory, passed straight to ``ZipFile.extractall``.
            ``None`` is forwarded unchanged, which the ``zipfile``
            documentation defines as the current working directory.

    Returns:
        None. Prints ``Extracted all`` on success, or a line starting with
        ``invalid file`` followed by the reason on failure.
    """
    if zip_file is None:
        print('invalid file: no zip_file given')
        return
    try:
        with zipfile.ZipFile(zip_file) as zip_ref:
            zip_ref.extractall(dst)
    except Exception as exc:
        # Deliberately broad so the helper stays best-effort, but no longer a
        # bare ``except``: KeyboardInterrupt / SystemExit now propagate, and
        # the real cause (missing file, corrupt archive, ...) is shown.
        print(f'invalid file: {exc}')
    else:
        print("Extracted all")
# Unpack the HIT-UAV archive next to where it was downloaded.
unzip(
    zip_file='/datasets/hituav-a-highaltitude-infrared-thermal-dataset.zip',
    dst='/datasets',
)

In [None]:
# Create the folder for the annotated predictions; exist_ok keeps this cell
# safe to re-run.
os.makedirs(name='/workspace/inference_results/images', exist_ok=True)

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Instantiate YOLO-NAS-S with num_classes=5 and move the result to `device`.
# NOTE: `...` is a placeholder left in the notebook -- replace it with the path
# to the trained checkpoint before running this cell.
model = models.get(
    model_name=models.Models.YOLO_NAS_S,
    checkpoint_path=...,
    num_classes=5,
).to(device)

# Inference on Test images

In [None]:
# Directory holding the test-split images.
ROOT_TEST = '/datasets/hit-uav/images/test/'
# os.listdir() returns entries in arbitrary order; sort them so the images are
# processed in the same order on every run and machine.
all_images = sorted(os.listdir(ROOT_TEST))

In [None]:
# Run the detector on every test image and keep each annotated prediction
# under the file name of its source image.
results_dir = 'inference_results/images'
os.makedirs(results_dir, exist_ok=True)
for image in tqdm(all_images, total=len(all_images)):
    image_path = os.path.join(ROOT_TEST, image)
    out = model.predict(image_path)  # was misspelled `predcit` (AttributeError)
    out.save(results_dir)
    # This loop expects save() to have written the prediction as `pred_0.jpg`
    # inside results_dir (TODO confirm for the installed super-gradients
    # version). Rename it there -- not inside ROOT_TEST, which holds the source
    # images and never receives a `pred_0.jpg` -- so the next iteration does
    # not overwrite it.
    os.rename(
        os.path.join(results_dir, 'pred_0.jpg'),
        os.path.join(results_dir, image),
    )

# Overlapped Ground Truth Samples

In [None]:
# Class names; a label's integer id is used as the index into this list.
# 'OtherVehicle' was misspelled 'OtherVechicle', and the name is drawn onto
# the images verbatim.
classes = ['Person', 'Car', 'Bicycle', 'OtherVehicle', 'DontCare']

In [None]:
# One random 3-component colour per class, each component drawn from [0, 255).
colors = np.random.uniform(low=0, high=255, size=(len(classes), 3))

In [None]:
def yolo2bbox(bboxes):
    """Convert one centre/size box into corner coordinates.

    Args:
        bboxes: Indexable with at least four numbers, read as
            (x_center, y_center, width, height).

    Returns:
        Tuple ``(xmin, ymin, xmax, ymax)`` in the same units as the input.
    """
    x_center, y_center = bboxes[0], bboxes[1]
    half_w = bboxes[2] / 2
    half_h = bboxes[3] / 2
    return (
        x_center - half_w,
        y_center - half_h,
        x_center + half_w,
        y_center + half_h,
    )

In [None]:
def plot_box(image, bboxes, labels):
    """Draw labelled bounding boxes onto ``image`` and return it.

    The image is modified in place by the OpenCV drawing calls; the same
    object is returned for convenience.

    Args:
        image: Array whose ``shape[:2]`` is (height, width).
        bboxes: Sequence of boxes, each handed to ``yolo2bbox``; the resulting
            corners are scaled by the image width/height, i.e. they are
            treated as normalised coordinates.
        labels: Indexable by box position; each entry is converted with
            ``int`` and used as an index into the module-level ``classes``.

    Returns:
        The annotated ``image``.
    """
    # The image size is needed to turn normalised coordinates into pixels.
    img_h, img_w = image.shape[:2]
    line_width = max(round(sum(image.shape) / 2 * 0.003), 2)
    text_thickness = max(line_width - 1, 1)

    for idx, yolo_box in enumerate(bboxes):
        left, top, right, bottom = yolo2bbox(yolo_box)
        # Denormalise the corners to integer pixel positions.
        top_left = (int(left * img_w), int(top * img_h))
        bottom_right = (int(right * img_w), int(bottom * img_h))

        label = classes[int(labels[idx])]
        box_color = colors[classes.index(label)]

        # Outline of the object.
        cv2.rectangle(
            image,
            top_left,
            bottom_right,
            color=box_color,
            thickness=line_width,
            lineType=cv2.LINE_AA,
        )

        # Size of the label text; note it is measured at line_width / 3 but
        # drawn below at line_width / 3.5.
        text_w, text_h = cv2.getTextSize(
            label,
            0,
            fontScale=line_width / 3,
            thickness=text_thickness,
        )[0]

        # Put the label above the box when there is room, otherwise inside it.
        fits_above = top_left[1] - text_h >= 3
        if fits_above:
            banner_y = top_left[1] - text_h - 3
            text_y = top_left[1] - 5
        else:
            banner_y = top_left[1] + text_h + 3
            text_y = top_left[1] + text_h + 2
        banner_corner = (top_left[0] + text_w, banner_y)

        # Filled background for the label, then the label itself in white.
        cv2.rectangle(
            image,
            top_left,
            banner_corner,
            color=box_color,
            thickness=-1,
            lineType=cv2.LINE_AA,
        )
        cv2.putText(
            image,
            label,
            (top_left[0], text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=line_width / 3.5,
            color=(255, 255, 255),
            thickness=text_thickness,
            lineType=cv2.LINE_AA,
        )
    return image