In [None]:
import cv2
import numpy as np
import os
from pathlib import Path
import torch
from super_gradients.training import models
from super_gradients.training import Trainer, models
from super_gradients.training import dataloaders
from super_gradients.training.dataloaders.dataloaders import (
    coco_detection_yolo_format_train, 
    coco_detection_yolo_format_val,
)
import super_gradients.training
super_gradients.setup_device(device='cuda')
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import (
    DetectionMetrics_050,
    DetectionMetrics,
)
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
from tqdm.auto import tqdm


In [None]:

# Runtime settings shared by the cells below.
DEVICE = 'cuda:0'
BATCH_SIZE = 1
WORKERS = 8
#model = torch.load('/home/matthias/workspace/Coding/00_vista_medizina/bone_frac_obj_det/yoloNAS/checkpoints/yolo_nas_s/RUN_20241117_212610_007684/ckpt_epoch_100.pth', map_location=torch.device(DEVICE))

# %%
# Dataset layout: ROOT_DIR/<split>/images and ROOT_DIR/<split>/labels.
ROOT_DIR = '/home/matthias/workspace/Coding/00_vista_medizina/00_data/2024-11-21/single_class_all_categories'
train_imgs_dir = f'{ROOT_DIR}/train/images'
train_labels_dir = f'{ROOT_DIR}/train/labels'
val_imgs_dir = f'{ROOT_DIR}/val/images'
val_labels_dir = f'{ROOT_DIR}/val/labels'
test_imgs_dir = f'{ROOT_DIR}/test/images'
test_labels_dir = f'{ROOT_DIR}/test/labels'
classes = ['fracture']

# %%
# Single place that gathers every dataset-related setting.
dataset_params = dict(
    data_dir=ROOT_DIR,
    train_images_dir=train_imgs_dir,
    train_labels_dir=train_labels_dir,
    val_images_dir=val_imgs_dir,
    val_labels_dir=val_labels_dir,
    test_images_dir=test_imgs_dir,
    test_labels_dir=test_labels_dir,
    classes=classes,
    ignore_empty_annotations=True,
)

# Data loader over the test split, built with the validation-style factory.
test_data = coco_detection_yolo_format_val(
    dataset_params=dict(
        data_dir=dataset_params['data_dir'],
        images_dir=dataset_params['test_images_dir'],
        labels_dir=dataset_params['test_labels_dir'],
        classes=dataset_params['classes'],
        ignore_empty_annotations=dataset_params['ignore_empty_annotations'],
    ),
    dataloader_params=dict(
        batch_size=BATCH_SIZE,
        num_workers=WORKERS,
    ),
)

# Display name -> checkpoint file of each model to inspect.
dict_model_paths = {
    'f1@0.5opt': "/home/matthias/workspace/Coding/00_vista_medizina/vista_bone_frac/yoloNAS/checkpoints/yolo_nas_s/RUN_20241123_101156_209514/ckpt_best.pth",
}

In [None]:
# Instantiate one 'yolo_nas_s' model per entry of dict_model_paths, keyed by
# the same display name.
dict_models_loaded = {}
for model_name, checkpoint_path in dict_model_paths.items():
    print("Model: ", model_name)
    # NOTE(review): num_classes is fixed to 80 while `classes` holds a single
    # entry -- confirm this matches the head the checkpoint was trained with.
    dict_models_loaded[model_name] = models.get(
        'yolo_nas_s',
        num_classes=80,
        checkpoint_path=checkpoint_path,
    )
print("Models: ", dict_models_loaded.keys())


def load_ground_truth(gt_txt_path) -> list[list[float | int]]:
    ground_truths = []
    with open(gt_txt_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            class_id = int(parts[0])
            # Convert from normalized to pixel coordinates
            x_min = float(parts[1]) #* img_width
            y_min = float(parts[2]) #* img_height
            x_max = float(parts[3]) #* img_width
            y_max = float(parts[4]) #* img_height
            
            # Calculate the bounding box corners (x_min, y_min, x_max, y_max)
            #x_min = int(x_center - width / 2)
            #y_min = int(y_center - height / 2)
            #x_max = int(x_center + width / 2)
            #y_max = int(y_center + height / 2)
            
            ground_truths.append([x_min, y_min, x_max, y_max, class_id])

    return ground_truths


def print_boxes_to_image(image_data: np.ndarray, boxes_info, ground_truth: bool):
    """Draw bounding boxes and captions onto ``image_data`` in place.

    Args:
        image_data: Image array whose first two axes are (height, width). It is
            modified in place by the OpenCV drawing calls.
        boxes_info: When ``ground_truth`` is true, a sequence of
            ``[x_center, y_center, width, height, class_id]`` records whose
            four geometry values are fractions of the image size. Otherwise an
            object exposing ``prediction.bboxes_xyxy``, ``prediction.labels``
            and ``prediction.confidence``, with boxes given as pixel corners.
        ground_truth: Selects which of the two ``boxes_info`` layouts is used
            and the drawing color: (0, 0, 255) for ground truth, (0, 255, 0)
            for predictions.

    Returns:
        None.
    """
    # Keep the image size in dedicated names: it must stay constant while the
    # per-box width/height are computed, otherwise every box after the first
    # would be scaled against the previous box instead of the image.
    img_height, img_width = image_data.shape[:2]

    if ground_truth:
        if len(boxes_info) == 0:
            cv2.putText(image_data, 'No fractures in ground truth data.', (20, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
            return

        for box in boxes_info:
            x_center_frac, y_center_frac, width_frac, height_frac, class_id = box

            # Fractions of the image size -> pixels.
            x_center = x_center_frac * img_width
            y_center = y_center_frac * img_height
            box_width = width_frac * img_width
            box_height = height_frac * img_height

            # Center/size -> top-left and bottom-right corners.
            x1 = int(x_center - box_width / 2)
            y1 = int(y_center - box_height / 2)
            x2 = int(x_center + box_width / 2)
            y2 = int(y_center + box_height / 2)

            box_text = f'GT Class {class_id}'

            cv2.rectangle(image_data, (x1, y1), (x2, y2), (0, 0, 255), 2)
            cv2.putText(image_data, box_text, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

    else:
        for idx, box in enumerate(boxes_info.prediction.bboxes_xyxy):
            x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])

            class_id = boxes_info.prediction.labels[idx]
            score = boxes_info.prediction.confidence[idx]
            box_text = f'Pred Class {class_id}: {score:.2f}'  # Use actual label or class mapping here

            cv2.rectangle(image_data, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(image_data, box_text, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)


def _find_files_with_suffix(root_dir: str, suffix: str) -> list[str]:
    """Return the sorted paths of all files under ``root_dir`` ending in ``suffix``."""
    return sorted(
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(root_dir)
        for filename in filenames
        if filename.endswith(suffix)
    )


def get_label_and_image_paths(path_images_dir: str, path_labels_dir: str) -> dict[str, list[str]]:
    """Collect image and label file paths by walking two directory trees.

    Args:
        path_images_dir: Directory searched recursively for ``.jpg`` files.
        path_labels_dir: Directory searched recursively for ``.txt`` files.

    Returns:
        A dict with two keys:
        ``'list_paths_images'`` -- paths of the ``.jpg`` files found, and
        ``'list_paths_labels'`` -- paths of the ``.txt`` files found.
        Both lists are sorted so the result does not depend on the order in
        which the file system enumerates entries. A directory that does not
        exist yields an empty list. The extension match is case-sensitive.
    """
    return {
        'list_paths_images': _find_files_with_suffix(path_images_dir, '.jpg'),
        'list_paths_labels': _find_files_with_suffix(path_labels_dir, '.txt'),
    }


# Gather the test-split image and label file paths once, up front.
dict_paths_labels_images = get_label_and_image_paths(
    dataset_params['test_images_dir'],
    dataset_params['test_labels_dir'],
)



In [None]:
# Every image is scored with the first entry of dict_models_loaded; look it up
# once instead of rebuilding the key list for each image.
preview_model = next(iter(dict_models_loaded.values()))

for path_image in dict_paths_labels_images['list_paths_images']:

    image = cv2.imread(path_image)
    if image is None:
        # cv2.imread reports an unreadable file by returning None rather than
        # raising; skip it instead of failing later while drawing.
        print("Could not read image, skipping: ", path_image)
        continue

    # NOTE(review): iou=0.0 presumably makes NMS drop any overlapping boxes --
    # confirm this is the intended threshold.
    predictions = preview_model.predict(
        images=path_image,
        iou=0.0,
        conf=0.2,
        max_predictions=2,
    )

    # The label file shares the image's base name and lives directly in
    # test_labels_dir.
    path_gt_data = Path(test_labels_dir) / (Path(path_image).stem + ".txt")
    if path_gt_data.exists():
        ground_truths = load_ground_truth(path_gt_data)
        print_boxes_to_image(image_data=image, boxes_info=ground_truths, ground_truth=True)

    print_boxes_to_image(image_data=image, boxes_info=predictions, ground_truth=False)

    # Show the annotated image and block until any key is pressed.
    cv2.imshow('', image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()