<a href="https://colab.research.google.com/github/TharinsaMudalige/Neuron-Brain_Tumor_Detection_Classification_with_XAI/blob/Detection-Classficiation-CNN/Generating_Annotations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Import libraries & Mount Google Drive

In [1]:
!pip install segmentation-models-pytorch torch torchvision albumentations opencv-python lxml

import os
import shutil
import xml.etree.ElementTree as ET
from xml.dom.minidom import parseString

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import torch
from google.colab import drive
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor

# Mount Google Drive so the DSGP dataset/output paths under /content/drive/MyDrive
# (used by DATASET_PATH and OUTPUT_PATH in the "Define Paths" cell) are available.
drive.mount('/content/drive')

Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.4.0-py3-none-any.whl.metadata (32 kB)
Collecting efficientnet-pytorch>=0.6.1 (from segmentation-models-pytorch)
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pretrainedmodels>=0.7.1 (from segmentation-models-pytorch)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cud

Define Paths

In [9]:
DATASET_PATH = "/content/drive/MyDrive/DSGP/Preprocessed_Dataset"  # input: <split>/<class>/<images>
OUTPUT_PATH = "/content/drive/MyDrive/DSGP/CNN_Dataset"  # output root for copied images + VOC XML

# Pre-create <OUTPUT_PATH>/<split>/Images and <OUTPUT_PATH>/<split>/Annotations
# for every split; exist_ok=True keeps this cell safe to re-run.
for split_name in ("Train", "Val", "Test"):
    split_root = os.path.join(OUTPUT_PATH, split_name)
    for sub_dir in ("Images", "Annotations"):
        os.makedirs(os.path.join(split_root, sub_dir), exist_ok=True)

Load U-Net Model (ImageNet-pretrained encoder; the decoder is not pretrained)

In [10]:
# Load pretrained U-Net model for segmentation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = smp.Unet(
    encoder_name="resnet34",      # Pretrained backbone
    encoder_weights="imagenet",   # Pretrained weights
    in_channels=3,                # RGB Images
    classes=1,                    # Single class (tumor mask)
).to(device)

model.eval()  # Set model to evaluation mode

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

Function to Preprocess Images

In [11]:
def preprocess_image(image_path, target_size=(256, 256), mean=None, std=None):
    """Load an image from disk and convert it into a batched U-Net input tensor.

    Args:
        image_path: Path of an image file readable by OpenCV.
        target_size: Output size as (width, height) -- the order cv2.resize uses.
        mean: Optional per-channel (R, G, B) mean subtracted after scaling to
            [0, 1], e.g. ImageNet's (0.485, 0.456, 0.406) to match an
            ImageNet-pretrained encoder. Omitted -> no mean subtraction.
        std: Optional per-channel (R, G, B) divisor applied after `mean`,
            e.g. ImageNet's (0.229, 0.224, 0.225). Omitted -> no division.

    Returns:
        float32 tensor of shape (1, 3, height, width) on the global `device`.

    Raises:
        FileNotFoundError: if OpenCV cannot read `image_path`.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, target_size)
    image = image / 255.0  # Normalize to [0, 1]
    if mean is not None:
        image = image - np.asarray(mean, dtype=np.float64)
    if std is not None:
        image = image / np.asarray(std, dtype=np.float64)
    image = np.transpose(image, (2, 0, 1))  # (H, W, C) -> (C, H, W)
    return torch.tensor(image, dtype=torch.float32).unsqueeze(0).to(device)


Function to Get Bounding Box from Segmentation Mask


In [12]:
def get_bounding_box(mask):
    """Return the box enclosing the largest foreground region of `mask`.

    A pixel is foreground when its value is strictly greater than 0.5, so the
    mask is expected to hold probabilities (e.g. sigmoid outputs), not logits.

    Returns:
        [xmin, ymin, xmax, ymax] in the mask's pixel coordinates, where
        xmax = x + width and ymax = y + height from cv2.boundingRect; or
        None when no foreground contour exists.
    """
    binary = (mask > 0.5).astype(np.uint8)
    contours, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Guard clause: nothing above threshold -> no tumour region.
    if len(contours) == 0:
        return None

    largest = max(contours, key=cv2.contourArea)
    left, top, box_w, box_h = cv2.boundingRect(largest)
    right, bottom = left + box_w, top + box_h
    return [left, top, right, bottom]


Function to Create PASCAL VOC XML Annotation

In [13]:
def create_pascal_voc_xml(image_path, bbox, label, save_dir):
    """Write a single-object PASCAL VOC XML annotation for `image_path`.

    Args:
        image_path: Source image; it is read with OpenCV to fill <size>.
        bbox: [xmin, ymin, xmax, ymax] in the pixel coordinates of that image.
            Values are rounded to integers, as VOC coordinates are integral.
        label: Class name stored in <object><name>.
        save_dir: Existing directory that receives "<image stem>.xml".

    Returns:
        Path of the XML file that was written.

    Raises:
        FileNotFoundError: if OpenCV cannot read `image_path`.
        ValueError: if `bbox` does not contain exactly four values.
    """
    if len(bbox) != 4:
        raise ValueError(f"bbox must be [xmin, ymin, xmax, ymax], got {bbox!r}")

    image_name = os.path.basename(image_path)
    xml_filename = os.path.splitext(image_name)[0] + ".xml"

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    height, width = img.shape[:2]
    depth = img.shape[2] if img.ndim == 3 else 1

    root = ET.Element("annotation")
    ET.SubElement(root, "folder").text = "Dataset"
    ET.SubElement(root, "filename").text = image_name
    ET.SubElement(root, "path").text = image_path

    size = ET.SubElement(root, "size")
    ET.SubElement(size, "width").text = str(width)
    ET.SubElement(size, "height").text = str(height)
    ET.SubElement(size, "depth").text = str(depth)
    ET.SubElement(root, "segmented").text = "0"

    obj = ET.SubElement(root, "object")
    ET.SubElement(obj, "name").text = label
    # Standard VOC object fields; VOC parsers commonly read <difficult>.
    ET.SubElement(obj, "pose").text = "Unspecified"
    ET.SubElement(obj, "truncated").text = "0"
    ET.SubElement(obj, "difficult").text = "0"

    bbox_elem = ET.SubElement(obj, "bndbox")
    for tag, value in zip(("xmin", "ymin", "xmax", "ymax"), bbox):
        # int(round()) so numpy scalars / floats serialise as plain integers.
        ET.SubElement(bbox_elem, tag).text = str(int(round(value)))

    xml_pretty = parseString(ET.tostring(root)).toprettyxml()

    xml_path = os.path.join(save_dir, xml_filename)
    with open(xml_path, "w", encoding="utf-8") as xml_file:
        xml_file.write(xml_pretty)
    return xml_path

Process Images and Generate Annotations

In [14]:
def _predict_bbox(image_path):
    """Run the U-Net on one image; return its box in original-image pixels, or None."""
    input_image = preprocess_image(image_path)
    with torch.no_grad():
        logits = model(input_image)
    # The model cell passes no `activation` to smp.Unet, so the network outputs
    # logits; sigmoid converts them to the probabilities get_bounding_box
    # thresholds at 0.5.
    prob_mask = torch.sigmoid(logits).squeeze().cpu().numpy()

    # The mask is at the network's input resolution (preprocess_image's
    # target_size). Resize it to the original image so the box agrees with the
    # <size> that create_pascal_voc_xml reads from the same file.
    original = cv2.imread(image_path)
    height, width = original.shape[:2]
    prob_mask = cv2.resize(prob_mask, (width, height))
    return get_bounding_box(prob_mask)


def process_and_split_data():
    """Build the detection dataset from U-Net masks for every split.

    For each DATASET_PATH/<split>/<class>/ folder, runs the global U-Net
    `model` on every .png/.jpg/.jpeg image and derives a bounding box from the
    predicted mask. When a box is found, copies the image to
    OUTPUT_PATH/<split>/Images/<class>/ and writes a PASCAL VOC XML, labelled
    with the class-folder name, to OUTPUT_PATH/<split>/Annotations/<class>/.
    Images with no detected region are reported and not copied.

    Raises:
        FileNotFoundError: if DATASET_PATH/<split> does not exist (from os.listdir).
    """
    for split in ["Train", "Val", "Test"]:
        input_images_path = os.path.join(DATASET_PATH, split)
        output_images_path = os.path.join(OUTPUT_PATH, split, "Images")
        output_annotations_path = os.path.join(OUTPUT_PATH, split, "Annotations")

        # sorted() makes processing order (and the log) reproducible across runs.
        for tumor_class in sorted(os.listdir(input_images_path)):
            class_images_path = os.path.join(input_images_path, tumor_class)

            if not os.path.isdir(class_images_path):
                continue  # Skip non-folder files

            images = sorted(
                img for img in os.listdir(class_images_path)
                if img.lower().endswith(('.png', '.jpg', '.jpeg'))
            )

            image_save_path = os.path.join(output_images_path, tumor_class)
            annotation_save_path = os.path.join(output_annotations_path, tumor_class)

            os.makedirs(image_save_path, exist_ok=True)
            os.makedirs(annotation_save_path, exist_ok=True)

            for image_file in images:
                image_path = os.path.join(class_images_path, image_file)
                bbox = _predict_bbox(image_path)

                if bbox is None:
                    print(f"Skipped {image_file} ({tumor_class}): no region detected")
                    continue

                # Copy image to the new dataset structure and save its annotation.
                shutil.copy(image_path, os.path.join(image_save_path, image_file))
                create_pascal_voc_xml(image_path, bbox, tumor_class, annotation_save_path)
                print(f"Processed {image_file} -> {tumor_class}")

process_and_split_data()
print("Dataset processing & annotation generation completed successfully!")

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/DSGP/Preprocessed_Dataset/Train'

Visualizing Segmentation for an Image

In [None]:
def visualize_segmentation(image_path, threshold=0.5):
    """Show an image, its predicted U-Net mask, and the mask overlaid in red.

    Args:
        image_path: Path of an image readable by OpenCV.
        threshold: Probability above which a pixel counts as foreground
            (0.5 matches get_bounding_box).

    Raises:
        FileNotFoundError: if OpenCV cannot read `image_path`.
    """
    original_image = cv2.imread(image_path)
    if original_image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    # Predict the mask at the network's input resolution.
    input_image = preprocess_image(image_path)
    with torch.no_grad():
        logits = model(input_image)
    # The model outputs logits (no activation is configured in the model cell);
    # convert them to probabilities before applying a probability threshold.
    prob_mask = torch.sigmoid(logits).squeeze().cpu().numpy()
    binary_mask = (prob_mask > threshold).astype(np.uint8)

    # Resize the mask to match the original image size (cv2 takes (width, height)).
    mask_resized = cv2.resize(binary_mask, (original_image.shape[1], original_image.shape[0]))

    overlay = original_image.copy()
    overlay[mask_resized > 0] = [255, 0, 0]  # Red where the mask is set

    fig, axes = plt.subplots(1, 3, figsize=(12, 5))
    panels = [
        (original_image, "Original Image", None),
        (mask_resized, "Segmented Mask", "gray"),
        (overlay, "Overlay Mask on Image", None),
    ]
    for ax, (img, title, cmap) in zip(axes, panels):
        ax.imshow(img, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    plt.show()

# Example usage (change the path to an actual image from your dataset).
example_image_path = "/content/drive/MyDrive/DSGP/Preprocessed_Dataset/Train/astrocitoma/sample.png"
if os.path.exists(example_image_path):
    visualize_segmentation(example_image_path)
else:
    print(f"Example image not found, skipping visualization: {example_image_path}")

Visualize a Segmented Image

In [None]:
# `visualize_segmentation` is defined in the previous cell. Redefining it here
# would silently shadow any change made there, so this cell only reuses it.
# It previews the first image found in the training split, so it works
# without a hand-edited placeholder path.
train_dir = os.path.join(DATASET_PATH, "Train")
example_image_path = next(
    (
        os.path.join(root, name)
        for root, _dirs, files in sorted(os.walk(train_dir))
        for name in sorted(files)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    ),
    None,
)

if example_image_path is None:
    print(f"No images found under {train_dir}; nothing to visualize.")
else:
    visualize_segmentation(example_image_path)