First, install the roifile package, which is used to read the data stored in ImageJ ROI files.
GitHub repository: https://github.com/cgohlke/roifile/tree/master

In [None]:
# Install (or upgrade) roifile together with its optional extras.
pip install -U roifile[all]

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from roifile import ImagejRoi, ROI_TYPE, ROI_OPTIONS
from bs4 import BeautifulSoup
import torchvision
from torchvision import transforms, datasets, models
from torchvision.datasets import ImageFolder
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from PIL import Image
import os
from google.colab import drive
drive.mount('/content/drive')
from google.colab import files
import io
import cv2
import time


In [None]:
# Demonstrate how the roifile package reads ROI data from an ImageJ ROI archive:
# load one annotation file and print what was parsed from it.
roi_instance = ImagejRoi.fromfile('/content/drive/MyDrive/object_detection/archive/annotations_imagej_roi/maksssksksss0.zip')
print(roi_instance)

In [None]:
#Define the function to read the images and annotation from ImageJ roi files and convert to the input format in pytorch
def generate_box(obj):
  """Return the bounding box of an ROI as ``[xmin, ymin, xmax, ymax]``.

  ``obj`` only needs to expose ``left``, ``top``, ``right`` and ``bottom``
  attributes; their values are passed through unchanged.
  """
  return [obj.left, obj.top, obj.right, obj.bottom]

def generate_label(obj):
  """Map the name of an ROI to an integer class label.

  The name is searched for the class keywords in a fixed order and the label
  of the first keyword found is returned:
  ``with_mask`` -> 1, ``mask_weared_incorrect`` -> 2, ``without_mask`` -> 3.
  A name that is not a string, or that contains none of the keywords, maps to 0.
  """
  roi_name = obj.name
  if not isinstance(roi_name, str):
    return 0

  # Order matters: it mirrors the precedence of the original if/elif chain.
  keyword_labels = (
    ('with_mask', 1),
    ('mask_weared_incorrect', 2),
    ('without_mask', 3),
  )
  for keyword, label in keyword_labels:
    if keyword in roi_name:
      return label
  return 0

def generate_target(image_id, roi_file_path):
  """Build a torchvision-style detection target from an ImageJ ROI file.

  Args:
    image_id: Identifier stored in the target under ``"image_id"``.
    roi_file_path: Path of the ImageJ ROI file (``.zip`` archive or single
      ``.roi`` file) holding the annotations of one image.

  Returns:
    A dict with
      ``"boxes"``    - float32 tensor of shape ``(N, 4)`` in
                       ``[xmin, ymin, xmax, ymax]`` order,
      ``"labels"``   - int64 tensor of shape ``(N,)``,
      ``"image_id"`` - tensor holding ``image_id``.
  """
  roi_objects = ImagejRoi.fromfile(roi_file_path)
  # fromfile() returns a bare ImagejRoi for a single-ROI file and a list for
  # an archive; normalise to a list so both cases are handled the same way.
  if isinstance(roi_objects, ImagejRoi):
    roi_objects = [roi_objects]

  boxes = [generate_box(roi) for roi in roi_objects]
  labels = [generate_label(roi) for roi in roi_objects]

  # reshape(-1, 4) keeps the (N, 4) layout even when there are no ROIs;
  # without it an empty annotation would yield a tensor of shape (0,).
  target = {
    "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
    "labels": torch.as_tensor(labels, dtype=torch.int64),
    "image_id": torch.tensor([image_id]),
  }
  return target

In [None]:
class MaskDataset(object):
    """Map-style dataset pairing each image with its ImageJ ROI annotation.

    The annotation of ``<name>.<ext>`` is looked up as ``<name>.zip`` inside
    ``annot_roi_path``.
    """

    def __init__(self, img_path, annot_roi_path, transforms):
        """Store the locations and index the image directory once.

        Args:
          img_path: Directory containing the image files.
          annot_roi_path: Directory containing the ``.zip`` ROI annotations.
          transforms: Callable applied to the loaded PIL image, or ``None``.
        """
        self.img_path = img_path
        self.annot_roi_path = annot_roi_path
        self.transforms = transforms
        # Sorted once so that an index always refers to the same file.
        self.imgs = sorted(os.listdir(img_path))

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th image of the sorted listing."""
        # Use the listing cached in __init__ instead of calling os.listdir()
        # again: the fresh listing was unsorted (inconsistent with self.imgs
        # and __len__) and cost a directory scan for every sample.
        file_image = self.imgs[idx]
        file_image_name = file_image.split('.')[0]
        file_annotation = file_image_name + '.zip'
        img_path = os.path.join(self.img_path, file_image)
        annotation_path = os.path.join(self.annot_roi_path, file_annotation)

        img = Image.open(img_path).convert("RGB")
        target = generate_target(idx, annotation_path)

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        """Return the number of files found in the image directory."""
        return len(self.imgs)

In [None]:
# Preprocessing pipeline applied to every image: a single ToTensor() step.
data_transform = transforms.Compose([transforms.ToTensor()])

In [None]:
def collate_fn(batch):
  """Transpose a batch of samples into a tuple of per-field tuples.

  A list of ``(image, target)`` pairs becomes ``(images, targets)``; nothing
  is stacked.
  """
  per_field = zip(*batch)
  return tuple(per_field)

# The dataset is a public face-mask detection dataset whose original annotations are XML.
# The annotation files were converted to ImageJ ROI format to show how to fine-tune an
# object detection model on ImageJ ROIs (custom dataset).

images_path = "/content/drive/MyDrive/object_detection/archive/images/"
annotations_path = "/content/drive/MyDrive/object_detection/archive/annotations_imagej_roi/"

dataset = MaskDataset(images_path,annotations_path,data_transform)

# Split the dataset; adjust the validation and test set sizes as needed
#val_size = int(0.2 * len(dataset))
test_size = int(0.01 * len(dataset))
train_size = len(dataset) - test_size

# Seed the split so that the same images end up in the test set on every run.
# The best model is saved to Drive and reloaded later; with an unseeded split a
# re-run could evaluate it on images it had been trained on.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=split_generator)

print('length of training dataset: ', len(train_dataset), '\n', 'length of testing dataset: ', len(test_dataset))

# Create the data loaders; collate_fn keeps each batch as tuples of images and targets
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,  num_workers=2, collate_fn=collate_fn)
#val_loader = DataLoader(val_dataset, batch_size=4, shuffle=True,  num_workers=2, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False,  num_workers=2, collate_fn=collate_fn)


In [None]:
def get_model_objective_detection(num_classes):
  """Create a Faster R-CNN (ResNet-50 FPN v2) detector with a new box head.

  Args:
    num_classes: Number of output classes of the new box predictor.

  Returns:
    The model, initialised with the default pre-trained weights and with its
    box predictor replaced by a fresh ``FastRCNNPredictor``.
  """
  # weights="DEFAULT" replaces the deprecated ``pretrained=True`` argument and
  # loads the same default pre-trained weights.
  model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights="DEFAULT")
  # number of input features expected by the classification layer
  in_features = model.roi_heads.box_predictor.cls_score.in_features
  # replace the pre-trained head with one sized for num_classes
  model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

  return model

In [None]:
# 4 classes: 'background' plus the three mask categories listed in class_names.
model = get_model_objective_detection(4)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
#Define the visualisation function to check whether the input is right
def plot_image(imgs, annotations):
  """Show one image with its annotated boxes and class names.

  Args:
    imgs: A single image tensor; it is transposed with ``(1, 2, 0)`` before
      being displayed.
    annotations: Dict with ``"boxes"`` (``[xmin, ymin, xmax, ymax]`` rows)
      and ``"labels"`` tensors.

  Boxes of label 1 are drawn in red, label 2 in blue and every other label in
  green. Class names are read from the module-level ``class_names`` list.
  """
  label_colors = {1: 'r', 2: 'b'}  # any other label falls back to green

  figure, axis = plt.subplots(1)
  figure.set_size_inches(5,5)
  pixels = imgs.cpu().numpy()
  gt_boxes = annotations["boxes"].detach().cpu().numpy()
  gt_labels = annotations["labels"].detach().cpu().numpy()
  # Display the image
  axis.imshow(np.transpose(pixels,(1, 2, 0)))

  for (xmin, ymin, xmax, ymax), label in zip(gt_boxes, gt_labels):
    color = label_colors.get(int(label), 'g')
    # Outline the box, then write its class name just above the top-left corner
    outline = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor=color,facecolor='none')
    axis.add_patch(outline)
    class_name = class_names[label]
    axis.text(xmin,(ymin-15),class_name,verticalalignment='top',color=color,fontsize=5,weight='bold')

  plt.show()

In [None]:
# Take the first training batch and plot one image with its annotation
# to check visually that the inputs are read correctly.
class_names = ['background', 'with_mask', 'mask_weared_incorrect', 'without_mask']

for imgs, annotations in train_loader:
  imgs = [img.to(device) for img in imgs]
  annotations = [
    {key: value.to(device) for key, value in target.items()}
    for target in annotations
  ]
  print(imgs[0].shape,'\n',annotations)
  break

print("Target")
plot_image(imgs[0], annotations[0])

In [None]:
#Start training
model.to(device)

num_epochs = 150

# Optimise only the parameters that require gradients
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# NOTE: the scheduler is created but deliberately not stepped below, so the
# learning rate stays constant for the whole run.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

len_dataloader = len(train_loader)

# One entry per epoch: [summed training loss]
history = []

# Lowest summed training loss seen so far (no validation set is used here)
best_loss = float('inf')
# Path where the best model is saved
best_model_path = '/content/drive/MyDrive/object_detection/archive/model_1.pth'

for epoch in range(num_epochs):
    epoch_start = time.time()
    model.train()
    # Running sum of the batch losses, kept as a detached tensor so that it
    # does not keep autograd history alive across iterations.
    training_loss = torch.zeros((), device=device)
    for i, (imgs, annotations) in enumerate(train_loader, start=1):
      imgs = [img.to(device) for img in imgs]
      annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
      # Feed the whole batch; previously only the first image of each batch
      # was passed to the model and the remaining ones were discarded.
      loss_dict = model(imgs, annotations)
      losses = sum(loss for loss in loss_dict.values())
      optimizer.zero_grad()
      losses.backward()
      optimizer.step()
      training_loss += losses.detach()
      #print(f'Iteration: {i}/{len_dataloader}, Loss: {losses}')
    history.append([training_loss])
    # Keep the weights of the epoch with the lowest training loss
    if training_loss < best_loss:
      best_loss = training_loss
      torch.save(model.state_dict(), best_model_path)
    epoch_end = time.time()
    # Print the training history of this epoch
    print("Epoch: {}/{},  Training: Loss: {:.4f},  Time: {:.4f}s".format(epoch+1, num_epochs, training_loss, epoch_end-epoch_start))

In [None]:
print(best_loss)
# Plot the per-epoch training loss recorded in ``history``
train_loss = [row[0].detach().cpu() for row in history]
#valid_loss = [row[1] for row in history]
plt.plot(train_loss, label='Train loss')
#plt.plot(valid_loss,label='Valid loss')
plt.title('Loss curve for training')
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.grid(True)
plt.show()

In [None]:
# Load the best model saved previously.
best_model_path = '/content/drive/MyDrive/object_detection/archive/model_1.pth'
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only runtime.
model.load_state_dict(torch.load(best_model_path, map_location=device))

In [None]:
#Predict the test dataset
all_imgs = []
all_annotations = []

# Gather every test image and annotation on the target device
for imgs, annotations in test_loader:
  imgs = [img.to(device) for img in imgs]
  annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
  all_imgs.extend(imgs)
  all_annotations.extend(annotations)

print(all_annotations)

model.eval()
model.to(device)
# Inference only: disable autograd so no graph is built for the predictions.
with torch.no_grad():
  preds = model(all_imgs)
preds

In [None]:
#Define the function to visualise the truth and prediction of test dataset
class_names = ['background', 'with_mask', 'mask_weared_incorrect', 'without_mask']
def plot_pred_truth_image(imgs, preds, annotations, score_threshold=0.0):
  """Show one image with its predicted and ground-truth boxes overlaid.

  Args:
    imgs: A single image tensor; it is transposed with ``(1, 2, 0)`` before
      being displayed.
    preds: Dict with ``"boxes"``, ``"labels"`` and ``"scores"`` tensors.
    annotations: Dict with ``"boxes"`` and ``"labels"`` tensors.
    score_threshold: Only predictions whose score is strictly greater than
      this value are drawn. The default of 0.0 matches the previous
      hard-coded ``score > 0`` filter.

  Predictions are drawn in yellow for labels 1 and 2 and in blue otherwise,
  each captioned with its class name (from the module-level ``class_names``)
  and score. Ground-truth boxes are drawn without a caption, in red for
  labels 1 and 2 and in green otherwise.
  """
  pred_colors = {1: 'y', 2: 'y'}   # any other label: blue
  truth_colors = {1: 'r', 2: 'r'}  # any other label: green

  fig,ax = plt.subplots(1)
  fig.set_size_inches(5,5)
  img = imgs.cpu().numpy()
  boxes_pred = preds["boxes"].detach().cpu().numpy()
  labels_pred = preds["labels"].detach().cpu().numpy()
  scores_pred = preds["scores"].detach().cpu().numpy()
  boxes = annotations["boxes"].detach().cpu().numpy()
  labels = annotations["labels"].detach().cpu().numpy()
  # Display the image
  ax.imshow(np.transpose(img,(1, 2, 0)))

  # Predictions: box outline plus "<class><score>" caption
  for box,label,score in zip(boxes_pred, labels_pred, scores_pred):
    if not score > score_threshold:
      continue
    xmin, ymin, xmax, ymax = box
    color = pred_colors.get(int(label), 'b')
    ax.add_patch(patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor=color,facecolor='none'))
    class_name = class_names[label]
    ax.text(xmin,(ymin-20),'{}{:.2f}'.format(class_name,score),verticalalignment='top',color=color,fontsize=6,weight='bold')

  # Ground truth: box outline only (the empty-string captions drawn before
  # produced no visible output and have been dropped)
  for box,label in zip(boxes, labels):
    xmin, ymin, xmax, ymax = box
    color = truth_colors.get(int(label), 'g')
    ax.add_patch(patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor=color,facecolor='none'))

  plt.show()

In [None]:
# Plot every test image with its prediction and ground truth
for i, pred in enumerate(preds):
  print(f'\n Prediction {i+1}')
  plot_pred_truth_image(all_imgs[i], pred, all_annotations[i])

In [None]:
#Import the evaluation package used to compute detection metrics on the test dataset, GitHub repository: https://github.com/rafaelpadilla/Object-Detection-Metrics
import sys
# Add the directory containing your module to the Python path
sys.path.append('/content/drive/MyDrive/Colab_Notebooks')

from _init_paths import *
from utils import *
from Evaluator import *
from BoundingBox import *
from BoundingBoxes import *


In [None]:
#Define the function to get the bounding boxes from truth and prediction annotation for calculating the metrics in next step.
def getBoundingBoxes(annotations, preds):
  """Collect ground-truth and predicted boxes into one ``BoundingBoxes`` object.

  Args:
    annotations: Sequence of dicts with ``"boxes"`` and ``"labels"`` tensors,
      one per image.
    preds: Sequence of dicts with ``"boxes"``, ``"labels"`` and ``"scores"``
      tensors, one per image.

  Returns:
    A ``BoundingBoxes`` container. The i-th entry of both sequences is
    registered under the image name ``'img<i>'``, and class ids are the names
    taken from the module-level ``class_names``. Only predictions with a
    score greater than 0 are added.
  """
  collected = BoundingBoxes()

  # Ground-truth boxes
  for index, annotation in enumerate(annotations):
    image_name = 'img'+ str(index)
    gt_boxes = annotation["boxes"].detach().cpu().numpy()
    gt_labels = annotation["labels"].detach().cpu().numpy()
    for (xmin, ymin, xmax, ymax), label in zip(gt_boxes, gt_labels):
      class_id = class_names[label]
      width = xmax-xmin
      height = ymax-ymin
      # NOTE(review): the box size is passed in the image-size position here,
      # exactly as before - confirm whether the evaluator relies on it.
      collected.addBoundingBox(
        BoundingBox(image_name,class_id,xmin,ymin,width,height,CoordinatesType.Absolute, (width, height),BBType.GroundTruth,format=BBFormat.XYWH))

  # Detected boxes
  for index, pred in enumerate(preds):
    image_name = 'img'+ str(index)
    det_boxes = pred["boxes"].detach().cpu().numpy()
    det_labels = pred["labels"].detach().cpu().numpy()
    det_scores = pred["scores"].detach().cpu().numpy()
    for box, label, score in zip(det_boxes, det_labels, det_scores):
      if not score > 0:
        continue
      xmin, ymin, xmax, ymax = box
      class_id = class_names[label]
      width = xmax-xmin
      height = ymax-ymin
      collected.addBoundingBox(
        BoundingBox(image_name,class_id,xmin,ymin,width,height,CoordinatesType.Absolute, (width, height),BBType.Detected,score,format=BBFormat.XYWH))

  return collected


In [None]:
# Collect the bounding boxes (ground truth and detections) of the test set
boundingboxes = getBoundingBoxes(all_annotations, preds)
boundingboxes

# Uncomment the line below to generate images based on the bounding boxes
# createImages(dictGroundTruth, dictDetected)
# Create an evaluator object in order to obtain the metrics
evaluator = Evaluator()
##############################################################
# VOC PASCAL Metrics
##############################################################
# Plot the Precision x Recall curve
evaluator.PlotPrecisionRecallCurve(
    boundingboxes,  # Object containing all bounding boxes (ground truths and detections)
    IOUThreshold=0.7,  # IOU threshold
    method=MethodAveragePrecision.EveryPointInterpolation,  # As the official matlab code
    showAP=True,  # Show the Average Precision in the title of the plot
    showInterpolatedPrecision=True)  # Plot the interpolated precision curve
# Compute the PASCAL VOC metrics with the same IOU threshold and method as the plot above
metricsPerClass = evaluator.GetPascalVOCMetrics(
    boundingboxes,  # Object containing all bounding boxes (ground truths and detections)
    IOUThreshold=0.7,  # IOU threshold
    method=MethodAveragePrecision.EveryPointInterpolation)  # As the official matlab code
print("Average precision values per class:\n")
# Loop through the classes to obtain their metrics
for mc in metricsPerClass:
    # Read the metric values of this class; only the class name and the AP are
    # printed below, the remaining values are just bound to names here
    c = mc['class']
    precision = mc['precision']
    recall = mc['recall']
    average_precision = mc['AP']
    ipre = mc['interpolated precision']
    irec = mc['interpolated recall']
    # Print the AP of this class
    print('%s: %f' % (c, average_precision))