# Automatic Classification & Detection of Bleeding and Non-Bleeding Frames in Wireless Capsule Endoscopy Using Swin Transformer and RT-DETR
This code classifies frames as bleeding or non-bleeding and, for frames classified as bleeding, annotates the bleeding regions with bounding boxes and labels. A Swin Transformer model is used for classification, and an RT-DETR detector exported to ONNX and run with ONNX Runtime is used for object detection.

## Importing Libraries

In [1]:
import torch
import os
import cv2
import onnxruntime as ort 
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms import ToTensor
from src.zoo.swin_transformer import SwinTransformer

## Classification and Detection Pipeline

In [12]:
image_path = '/home/ee22s501/cvip/data_test/Test Dataset 2/A0336.png'
model_path = '/home/ee22s501/cvip/code/save/model_classify_72.pth'
file_name = '/home/ee22s501/cvip/code/save/model.onnx'
save_path = '/home/ee22s501/cvip/code/save/figs'

# Run on the GPU when available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the Swin Transformer model configuration.
# NOTE(review): 'stochastic_depth_prob' is not read anywhere below; the
# constructor receives a hard-coded drop_path_rate=0.1 instead. Drop path is
# presumably inactive in eval mode (confirm in src.zoo.swin_transformer), so
# this should not affect inference.
swin_config = {
    'img_size': 224,
    'patch_size': 4,
    'in_chans': 3,
    'num_classes': 2,
    'embed_dim': 96,
    'depths': [2, 2, 6, 2],
    'num_heads': [3, 6, 12, 24],
    'window_size': 7,
    'mlp_ratio': 4,
    'stochastic_depth_prob': 0.2,
}

# Instantiate the Swin Transformer classifier on the selected device.
model = SwinTransformer(img_size=swin_config['img_size'],
                        patch_size=swin_config['patch_size'],
                        in_chans=swin_config['in_chans'],
                        num_classes=swin_config['num_classes'],
                        embed_dim=swin_config['embed_dim'],
                        depths=swin_config['depths'],
                        num_heads=swin_config['num_heads'],
                        window_size=swin_config['window_size'],
                        mlp_ratio=swin_config['mlp_ratio'],
                        qkv_bias=True,
                        qk_scale=None,
                        drop_rate=0.0,
                        drop_path_rate=0.1,
                        ape=False,
                        patch_norm=True,
                        use_checkpoint=False,
                        fused_window_process=False).to(device)

# map_location remaps the stored tensors onto `device`. Without it, torch.load
# restores tensors to the device they were saved from, which fails on a
# CPU-only machine when the checkpoint was saved during a GPU run.
# torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Class names; index i corresponds to output column i of the classifier.
class_names = ['bleeding', 'non_bleeding']

# OpenCV text-drawing settings. The detection code below draws with PIL and
# does not reference these.
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.5
font_color = (255, 255, 255)
font_thickness = 1

# Function to predict whether an image is bleeding or not
def predict_image(image_path):
    """Classify one image file as 'bleeding' or 'non_bleeding'.

    Uses the module-level ``model``, ``device``, ``class_names`` and
    ``swin_config``.

    Args:
        image_path: Path to an image file readable by OpenCV.

    Returns:
        The entry of ``class_names`` whose model output is largest.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file. ``cv2.imread``
            returns ``None`` instead of raising, which would otherwise
            surface later as an obscure tensor-construction error.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # The model is constructed with img_size=swin_config['img_size'] and
    # presumably expects inputs of exactly that size (confirm), so resize
    # anything else. Images already at that size are left untouched.
    side = swin_config['img_size']
    if image.shape[:2] != (side, side):
        image = cv2.resize(image, (side, side))

    # HWC uint8 -> 1xCxHxW float32 in [0, 1].
    # NOTE(review): cv2.imread yields BGR channel order and no mean/std
    # normalization is applied; confirm this matches the training pipeline.
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    image_tensor = image_tensor.to(device)
    with torch.no_grad():
        output = model(image_tensor)
    predicted = output.argmax(dim=1)
    return class_names[predicted.item()]
    
# Predict the label for the input image.
predicted_label = predict_image(image_path)


# Only frames classified as bleeding are passed to the ONNX detector. The
# annotated frame is saved to save_path under the input's base file name.
if predicted_label == 'bleeding':
    # Initialize ONNX inference session.
    sess = ort.InferenceSession(file_name)

    # Detections with a score at or below this threshold are discarded.
    thrh = 0.4

    # Load a font for labeling. This path is relative to the current working
    # directory (it points into a Pillow source tree), so fall back to
    # Pillow's default font instead of failing when the file is absent.
    try:
        fnt = ImageFont.truetype("Pillow/Tests/fonts/FreeMonoBold.ttf", 35)
    except OSError:
        fnt = ImageFont.load_default()

    # Load and preprocess the image: RGB, resized to 640x640, converted to a
    # 1x3x640x640 float tensor by ToTensor. Keep the input size so the
    # annotated result can be written back at the same resolution.
    im = Image.open(image_path).convert('RGB')
    orig_size = im.size
    im = im.resize((640, 640))
    im_data = ToTensor()(im)[None]
    # Presumably used by the exported model to scale boxes to this size.
    size = torch.tensor([[640, 640]])

    # Perform inference
    labels, boxes, scores = sess.run(
        output_names=['labels', 'boxes', 'scores'],
        input_feed={'images': im_data.numpy(), "orig_target_sizes": size.numpy()},
    )

    # Draw the detections on the 640x640 image.
    draw = ImageDraw.Draw(im)

    for i in range(im_data.shape[0]):
        # Apply one mask to labels, boxes AND scores so the k-th kept label,
        # box and score all refer to the same detection.
        keep = scores[i] > thrh
        for label, b, confidence in zip(labels[i][keep], boxes[i][keep], scores[i][keep]):
            draw.rectangle(list(b), outline='blue', width=7)
            if label == 1:
                # Leading newlines presumably push the caption down from the
                # box's top edge.
                text = f"\n \n \nBleeding \n({confidence:.2f})"
            else:
                text = f"{label} ({confidence:.2f})"
            draw.text((b[0], b[1]), text=text, font=fnt, fill='yellow')

    # Restore the input resolution and save under the same base name,
    # creating the output directory if needed.
    im = im.resize(orig_size)
    os.makedirs(save_path, exist_ok=True)
    save_filename = os.path.join(save_path, os.path.basename(image_path))
    im.save(save_filename)

    print("Image is bleeding and bounding box annotation is complete.")
else:
    print("The image is classified as non-bleeding.")

Image is bleeding and bounding box annotation is complete.
