# Detection of Bleeding Regions in Wireless Capsule Endoscopy Using RT-DETR
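This notebook trains an RT-DETR object detector to find bleeding regions in wireless capsule endoscopy images, exports the trained model to ONNX for deployment, and runs the ONNX model on test images, saving copies annotated with the detected regions.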

## Importing Libraries

In [8]:
#import src.misc.dist as dist 
import torch
import onnx
import os
import torch.nn as nn
from src.core import YAMLConfig 
from src.solver import TASKS
import onnxruntime as ort 
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms import ToTensor

## Path Configuration

In [9]:
# Project root. Override it with the CVIP_ROOT environment variable to run the
# notebook on another machine without editing the paths below.
CVIP_ROOT = os.environ.get('CVIP_ROOT', '/home/ee22s501/cvip')

config_path = os.path.join(CVIP_ROOT, 'code', 'configs', 'rtdetr', 'rtdetr_r101vd_6x_coco.yml')  # RT-DETR YAML config
resume_path = os.path.join(CVIP_ROOT, 'code', 'save', 'model_detect_71.pth')  # checkpoint loaded for export
file_name = os.path.join(CVIP_ROOT, 'code', 'save', 'model.onnx')  # where the ONNX model is written/read
save_path = os.path.join(CVIP_ROOT, 'code', 'save', 'figs')  # output directory for annotated images

## Training & Validation
This cell builds the solver for the task named in the YAML config and trains the model. Validation can optionally be run afterwards.

In [4]:
#dist.init_distributed(backend='nccl')

cfg = YAMLConfig(config_path, use_amp=True)

solver = TASKS[cfg.yaml_cfg['task']](cfg)

solver.fit()

#Uncomment this for validation
#solver.val()

Start training


Downloading: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth" to /home/ee22s501/.cache/torch/hub/checkpoints/ResNet101_vd_ssld_pretrained_from_paddle.pth
100%|██████████| 163M/163M [00:11<00:00, 15.2MB/s] 


Load PResNet101 state_dict
Initial lr:  [1e-06, 0.0001]
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
number of params: 76367466
Epoch: [0]  [  0/261]  eta: 0:05:45  lr: 0.000001  loss: 48.6823 (48.6823)  loss_vfl: 0.3831 (0.3831)  loss_bbox: 2.0977 (2.0977)  loss_giou: 1.5662 (1.5662)  loss_vfl_aux_0: 0.3206 (0.3206)  loss_bbox_aux_0: 2.0852 (2.0852)  loss_giou_aux_0: 1.5705 (1.5705)  loss_vfl_aux_1: 0.3323 (0.3323)  loss_bbox_aux_1: 2.0852 (2.0852)  loss_giou_aux_1: 1.5705 (1.5705)  loss_vfl_aux_2: 0.3464 (0.3464)  loss_bbox_aux_2: 2.0828 (2.0828)  loss_giou_aux_2: 1.6179 (1.6179)  loss_vfl_aux_3: 0.3389 (0.3389)  loss_bbox_aux_3: 2.0852 (2.0852)  loss_giou_aux_3: 1.5705 (1.5705)  loss_vfl_aux_4: 0.3037 (0.3037)  loss_bbox_aux_4: 2.1102 (2.1102)  loss_giou_aux_4: 1.5705 (1.5705)  loss_vfl_aux_5: 0.2378 (0.2378)  loss_bbox_aux_5: 2.0936 (2.0936)  loss_giou_aux_5: 1.6

## Export to ONNX
Loads a trained checkpoint into the model and exports the model together with its postprocessor to ONNX for deployment. It then uses `onnx.checker` to verify that the exported file is a well-formed ONNX model.

In [10]:
# Build the model config (with the checkpoint path as `resume`), then load the
# checkpoint's weights into `cfg.model` on the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
cfg = YAMLConfig(config_path, resume=resume_path)
checkpoint = torch.load(resume_path, map_location='cpu')

# Prefer the 'ema' weights (presumably exponential-moving-average weights) when the
# checkpoint has them; otherwise fall back to the plain 'model' state dict.
state = checkpoint['ema']['module'] if 'ema' in checkpoint else checkpoint['model']

cfg.model.load_state_dict(state)

# Wrap the deploy-mode model and postprocessor in one module so the exported graph
# takes images plus target sizes and returns the postprocessed detections.
class Model(nn.Module):
    """Export wrapper that chains a detector with its postprocessor.

    Args:
        config: Object exposing ``model`` and ``postprocessor`` attributes, each with a
            ``deploy()`` method. Defaults to the notebook-level ``cfg`` so ``Model()``
            keeps working. Pass a config explicitly to avoid depending on that global.
    """

    def __init__(self, config=None) -> None:
        super().__init__()
        if config is None:
            # Backward-compatible default: the notebook-global config.
            config = cfg
        self.model = config.model.deploy()
        self.postprocessor = config.postprocessor.deploy()
        # Sanity check: print the postprocessor's deploy_mode flag (True in the recorded run).
        print(self.postprocessor.deploy_mode)

    def forward(self, images, orig_target_sizes):
        """Run the detector on ``images`` and postprocess using ``orig_target_sizes``.

        Returns whatever the postprocessor returns. The export call names these
        outputs 'labels', 'boxes' and 'scores'.
        """
        outputs = self.model(images)
        return self.postprocessor(outputs, orig_target_sizes)


model = Model()

# Only the batch dimension is dynamic; the spatial size is fixed by the dummy input below.
dynamic_axes = {
    'images': {0: 'N', },
    'orig_target_sizes': {0: 'N'}
}

# Dummy inputs used to trace the graph: one random 3x640x640 image and its target size.
data = torch.rand(1, 3, 640, 640)
size = torch.tensor([[640, 640]])  # kept under this name: later cells may reuse `size`

# torch.onnx.export does not create missing directories, so make sure the target exists.
export_dir = os.path.dirname(file_name)
if export_dir:
    os.makedirs(export_dir, exist_ok=True)

torch.onnx.export(
    model, 
    (data, size), 
    file_name,
    input_names=['images', 'orig_target_sizes'],
    output_names=['labels', 'boxes', 'scores'],
    dynamic_axes=dynamic_axes,
    opset_version=16, 
    verbose=False
)

# Reload the exported file and verify it is a structurally valid ONNX model.
onnx_model = onnx.load(file_name)
onnx.checker.check_model(onnx_model)
print('Check export onnx model done...')

Load PResNet101 state_dict
True
verbose: False, log level: Level.ERROR

Check export onnx model done...


## Inference
Runs the exported ONNX model on the PNG images in the test dataset directory. For every detection whose score exceeds the confidence threshold, it draws a bounding box and a confidence label. The annotated images are saved to `save_path`.

In [13]:
test_dataset_dir = '/home/ee22s501/cvip/data_classify/val/bleeding'

# Load the exported ONNX model
sess = ort.InferenceSession(file_name)

# Model input resolution (must match the 640x640 export shape) and the size annotated images are saved at
INPUT_SIZE = 640
OUTPUT_SIZE = 224

# Threshold for object detection confidence
thrh = 0.4

# Load a font for labeling. The TrueType path is relative to the working directory,
# so fall back to Pillow's built-in font instead of crashing if it is missing.
try:
    fnt = ImageFont.truetype("Pillow/Tests/fonts/FreeMonoBold.ttf", 35)
except OSError:
    fnt = ImageFont.load_default()

# Boxes are drawn on the resized INPUT_SIZE x INPUT_SIZE image, so that is the target
# size handed to the postprocessor. Defining it here keeps this cell independent of
# the `size` variable created in the export cell.
target_sizes = torch.tensor([[INPUT_SIZE, INPUT_SIZE]]).numpy()

os.makedirs(save_path, exist_ok=True)


def _draw_detections(draw, labels, boxes, scores, threshold, font):
    """Draw a box and a "<name> (<score>)" caption for each detection scoring above ``threshold``.

    ``labels``, ``boxes`` and ``scores`` are the per-image outputs of the ONNX model.
    The same boolean mask is applied to all three, so each caption shows the score of
    its own box. Label 1 is captioned as "Bleeding"; other labels show their id.
    """
    keep = scores > threshold
    for label, box, confidence in zip(labels[keep], boxes[keep], scores[keep]):
        name = "Bleeding" if label == 1 else f"{label}"
        draw.rectangle(list(box), outline='blue', width=7)
        draw.text((box[0], box[1]), text=f"{name} ({confidence:.2f})", font=font, fill='yellow')


# Loop through the PNG images in the test dataset directory in a deterministic order
for filename in sorted(os.listdir(test_dataset_dir)):
    if not filename.lower().endswith(".png"):
        continue
    image_path = os.path.join(test_dataset_dir, filename)

    # Load and preprocess the image; the context manager closes the source file
    with Image.open(image_path) as src:
        im = src.convert('RGB').resize((INPUT_SIZE, INPUT_SIZE))
    im_data = ToTensor()(im)[None]

    # Perform inference
    labels, boxes, scores = sess.run(
        output_names=['labels', 'boxes', 'scores'],
        input_feed={'images': im_data.numpy(), 'orig_target_sizes': target_sizes}
    )

    # Annotate the image in place
    draw = ImageDraw.Draw(im)
    for i in range(im_data.shape[0]):
        _draw_detections(draw, labels[i], boxes[i], scores[i], thrh, fnt)

    # Resize and save the annotated image with the same name
    im = im.resize((OUTPUT_SIZE, OUTPUT_SIZE))
    im.save(os.path.join(save_path, filename))

print("Bounding box annotation and saving complete.")

Bounding box annotation and saving complete.
