In [1]:
# !pip install torch
# !pip install torchvision

In [2]:
import os
import xml.etree.ElementTree as ET
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F
from PIL import Image
from torchvision.ops import nms
from sklearn.metrics import precision_recall_curve, auc
import copy
import gc
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import matplotlib.patches as patches


In [3]:

# set direction and set parameters
data_dir = './Data'
image_dir = os.path.join(data_dir, 'images')
label_dir = os.path.join(data_dir, 'xmls')
use_subset = False
sub_percentage = 0.1
train_batch_size = 30
test_batch_size = 1
step_size = 50
gamma = 0.001
lr = 0.001
nms_step_size = 0.1
nms_step = 3
weight_decay= 0.001
num_epochs = 30
nms_iou_thresh = 0.4
mAP_iou_threshold= 0.5
score_thresh_init = 0.0
num_classes = 5
local_model_path = './faster_rcnn_model.pth'
patience = 5
num_worker = os.cpu_count()

In [4]:

# # set direction and set parameters
# data_dir = './Data'
# image_dir = os.path.join(data_dir, 'images')
# label_dir = os.path.join(data_dir, 'xmls')
# use_subset = True
# sub_percentage = 0.1
# train_batch_size = 4
# test_batch_size = 1
# step_size = 20
# gamma = 0.005
# lr = 0.005
# weight_decay= 0.001
# num_epochs = 10
# nms_iou_thresh = 0.01
# mAP_iou_threshold = 0.5
# score_thresh = 0.4
# num_classes = 5
# local_model_path = './faster_rcnn_model.pth'
# patience = 2


In [5]:

# dataset class
class RDD2022Dataset(Dataset):
    def __init__(self, image_dir, label_dir, transforms=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transforms = transforms
        self.classes = ['Background', 'D00', 'D10', 'D20', 'D40']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        print(self.class_to_idx)

        self.images = list(sorted(os.listdir(image_dir)))
        self.labels = list(sorted(os.listdir(label_dir)))

    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        # load image and label
        img_path = os.path.join(self.image_dir, self.images[idx])
        img = Image.open(img_path).convert("RGB")
        label_path = os.path.join(self.label_dir, self.labels[idx])
        tree = ET.parse(label_path)
        root = tree.getroot()
        boxes = []
        labels = []
        for obj in root.iter('object'):
            cls = obj.findtext('name')
            labels.append(self.class_to_idx[cls])
            xmlbox = obj.find('bndbox')
            bbox_coords = ['xmin', 'ymin', 'xmax', 'ymax']
            boxes.append([int(float(xmlbox.findtext(tag))) for tag in bbox_coords])
        # if have label
        if len(boxes) > 0:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # if it doesn't have label
        else:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros(0, dtype=torch.int64)
            area = torch.zeros(0, dtype=torch.float32)

        image_id = torch.tensor([idx])
        iscrowd = torch.zeros(len(boxes), dtype=torch.int64)
        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd
        }
        if self.transforms:
            img, target = self.transforms(img, target)
        return img, target


In [6]:

def get_transform():
    def transform(img, target):
        img = F.to_tensor(img)
        return img, target
    return transform

def get_model(num_classes):
    model = fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

# Create dataset and data loader
train_dataset = RDD2022Dataset(
    os.path.join(image_dir, 'train'),
    os.path.join(label_dir, 'train'),
    transforms=get_transform()
)

val_dataset = RDD2022Dataset(
    os.path.join(image_dir, 'val'),
    os.path.join(label_dir, 'val'),
    transforms=get_transform()
)

def create_subset(dataset, subset_size):
    # Check subset condition
    if isinstance(subset_size, float) and subset_size > 0 and subset_size < 1:
        subset_size = int(len(dataset) * subset_size)
    elif isinstance(subset_size, int) and subset_size > 0 and subset_size < len(dataset):
        subset_size = subset_size
    else:
        raise ValueError("subset_size must be a positive integer or a float between 0 and 1.")
    # Generate a random subset
    indices = torch.randperm(len(dataset))[:subset_size]
    subset = Subset(dataset, indices)
    return subset

# DataLoader setup
# if using subset
if use_subset:
    train_subset_size = sub_percentage 
    val_subset_size = sub_percentage
    train_subset = create_subset(train_dataset, train_subset_size)
    val_subset = create_subset(val_dataset, val_subset_size)
    train_loader = DataLoader(train_subset, batch_size=train_batch_size, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
    val_loader = DataLoader(val_subset, batch_size=train_batch_size, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))
    print('Using subset')
# if using full dataset
else:
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
    val_loader = DataLoader(val_dataset, batch_size=train_batch_size, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))
    print('Using full dataset')

# Clear CUDA cache
def clear_memory():
    torch.cuda.empty_cache()
    gc.collect()


{'Background': 0, 'D00': 1, 'D10': 2, 'D20': 3, 'D40': 4}
{'Background': 0, 'D00': 1, 'D10': 2, 'D20': 3, 'D40': 4}
Using full dataset


In [7]:

# calculate Intersection over Union (IoU)
def calculate_iou(gt_box, pred_box):
    xi1 = max(gt_box[0], pred_box[0])
    yi1 = max(gt_box[1], pred_box[1])
    xi2 = min(gt_box[2], pred_box[2])
    yi2 = min(gt_box[3], pred_box[3])
    inter_area = max(xi2 - xi1, 0) * max(yi2 - yi1, 0)
    
    gt_area = (gt_box[2] - gt_box[0]) * (gt_box[3] - gt_box[1])
    pred_area = (pred_box[2] - pred_box[0]) * (pred_box[3] - pred_box[1])
    union_area = gt_area + pred_area - inter_area
    
    if union_area == 0:
        return 0
    iou = inter_area / union_area
    return iou

# apply Non-Max Suppression (NMS) to filter predictions
def apply_nms(orig_prediction, epoch, iou_thresh=nms_iou_thresh, score_thresh_init=score_thresh_init, nms_step = nms_step, step_size = nms_step_size):
    # if no detections to process
    if orig_prediction['scores'].nelement() == 0:
        return {
            'boxes': torch.empty((0, 4), dtype=torch.float32),
            'scores': torch.empty((0,), dtype=torch.float32),
            'labels': torch.empty((0,), dtype=torch.int64)
        }
    # apply NMS
    boxes = orig_prediction['boxes'].float()
    scores = orig_prediction['scores'].float()
    if epoch//nms_step <= 5:
        score_thresh = score_thresh_init + (epoch//nms_step)*nms_step_size
    else:
        score_thresh = 0.5
    high_score_mask = scores > score_thresh
    boxes = boxes[high_score_mask]
    scores = scores[high_score_mask]
    labels = orig_prediction['labels'][high_score_mask]
    keep = nms(boxes, scores, iou_thresh)
    return {'boxes': boxes[keep], 'scores': scores[keep],'labels': labels[keep]}

# update true positives, false positives, and false negatives for each class
def compute_tp_fp_fn_for_image(epoch_index, gt_boxes, gt_labels, pred_boxes, pred_scores, pred_labels, num_classes, iou_threshold= mAP_iou_threshold):
    # iterate through each class (ignoring the background class)
    for cls in range(1, num_classes):
        # filter ground truth and predictions for the current class
        gt_idx = np.where(gt_labels == cls)[0]
        pred_idx = np.where(pred_labels == cls)[0]

        if len(gt_idx) == 0 and len(pred_idx) == 0:
            continue
        elif len(gt_idx) == 0:
            epoch_index[cls-1,1] += len(pred_boxes[pred_idx])
            continue
        elif len(pred_idx) == 0:
            epoch_index[cls-1,2] += len(gt_boxes[gt_idx])
            continue

        class_gt_boxes = gt_boxes[gt_idx]
        class_pred_boxes = pred_boxes[pred_idx]
        class_pred_scores = pred_scores[pred_idx]

        # sort predictions by scores in descending order
        sorted_indices = np.argsort(-class_pred_scores)
        class_pred_boxes = class_pred_boxes[sorted_indices]

        # initialize counters for true positives, false positives, and false negatives
        tp = 0
        detected = np.zeros(len(class_gt_boxes), dtype=bool)

        # check each predicted box against all GT boxes of the same class
        for pred_box in class_pred_boxes:
            ious = [calculate_iou(pred_box, gt_box) for gt_box in class_gt_boxes]
            max_iou = max(ious)
            max_gt_idx = np.argmax(ious)

            if max_iou >= iou_threshold and not detected[max_gt_idx]:
                tp += 1
                detected[max_gt_idx] = True

        # update class metrics
        epoch_index[cls-1,0] += tp
        epoch_index[cls-1,1] += (len(class_pred_boxes) - tp)  # All non-TPs are FPs
        epoch_index[cls-1,2] += (len(class_gt_boxes) - np.sum(detected))  # GTs not detected are FNs

    return epoch_index


In [8]:

# evaluate the model and calculate mAP along with precision, recall, and F1 scores
def evaluate_or_test_model(model, data_loader, device, epoch):
    model.eval()
    # total_val_loss = 0
    epoch_index = np.zeros((num_classes-1, 3))
    all_detections = [[] for _ in range(num_classes-1)]
    all_annotations = [[] for _ in range(num_classes-1)]

    with torch.no_grad():
        # for each evaluation batch
        for images, targets in tqdm(data_loader, desc="Evaluating", leave=False):
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            outputs = model(images)
            for i, output in enumerate(outputs):
                nms_output = apply_nms(output, epoch)
                pred_boxes = nms_output['boxes'].cpu().numpy()
                pred_labels = nms_output['labels'].cpu().numpy()
                pred_scores = nms_output['scores'].cpu().numpy()
                gt_boxes = targets[i]['boxes'].cpu().numpy()
                gt_labels = targets[i]['labels'].cpu().numpy()
                epoch_index = compute_tp_fp_fn_for_image(epoch_index, gt_boxes, gt_labels, pred_boxes, pred_scores, pred_labels, num_classes)

                # Store detections and annotations for mAP calculation
                for label in range(1, num_classes):
                    pred_indices = np.where(pred_labels == label)[0]
                    gt_indices = np.where(gt_labels == label)[0]
                    all_detections[label-1].extend([(box, score) for box, score in zip(pred_boxes[pred_indices], pred_scores[pred_indices])])
                    all_annotations[label-1].extend(gt_boxes[gt_indices])

            # Compute losses
            # if targets:
            #     model.train()
            #     loss_dict = model(images, targets)
            #     losses = sum(loss for loss in loss_dict.values())
            #     total_val_loss += losses.item()
            #     model.eval()

            del images, targets, outputs
            clear_memory()

    # avg_loss = total_val_loss / len(data_loader)

    # calculate mAP for each class
    avg_precisions = []
    for class_detections, class_annotations in zip(all_detections, all_annotations):
        if not class_detections or not class_annotations:
            continue
        # sort detections by decreasing confidence
        sorted_detections = sorted(class_detections, key=lambda x: x[1], reverse=True)
        tp = np.zeros(len(sorted_detections))
        fp = np.zeros(len(sorted_detections))
        matched = np.zeros(len(class_annotations), dtype=bool)  # Tracks which GT boxes have been matched

        for d_idx, (bbox, score) in enumerate(sorted_detections):
            ious = [calculate_iou(bbox, gt_bbox) for gt_bbox in class_annotations]
            best_iou = max(ious) if ious else 0
            best_idx = np.argmax(ious) if ious else -1

            if best_iou >= mAP_iou_threshold:
                if not matched[best_idx]:
                    tp[d_idx] = 1
                    matched[best_idx] = True
                else:
                    fp[d_idx] = 1
            else:
                fp[d_idx] = 1

        # accumulate true positives and false positives for precision-recall calculation
        acc_tp = np.cumsum(tp)
        acc_fp = np.cumsum(fp)
        recall = acc_tp / len(class_annotations) if len(class_annotations) > 0 else np.zeros_like(acc_tp)
        precision = acc_tp / (acc_tp + acc_fp) if (acc_tp + acc_fp).any() else np.zeros_like(acc_tp)

        # calculate area under the precision-recall curve (AUC)
        ap = auc(recall, precision)
        avg_precisions.append(ap)

    # compute the mean of average precisions across all classes
    mAP50 = np.mean(avg_precisions) if avg_precisions else 0

    # calculate overall precision, recall, and F1 score
    tp_total = epoch_index[:, 0]
    fp_total = epoch_index[:, 1]
    fn_total = epoch_index[:, 2]
    recall = tp_total / (tp_total + fn_total + 1e-6)
    precision = tp_total / (tp_total + fp_total + 1e-6)
    f1_score = 2 * (precision * recall) / (precision + recall + 1e-6)

    # return avg_loss, mean_ap, precision, recall, f1_score

    return mAP50, precision, recall, f1_score


In [None]:

# train the model
def train_model(model, train_loader, val_loader, optimizer, device, num_epochs=num_epochs, patience=patience):
    best_mAP50 = 0
    last_mAP50 = 0
    best_model_wts = None
    epochs_no_improve = 0
    
    for epoch in range(num_epochs):
        total_train_loss = 0
        for images, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False):
            model.train()
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            
            optimizer.zero_grad()
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            total_train_loss += losses.item()
            losses.backward()
            optimizer.step()
            del images, targets, loss_dict 
            clear_memory()
        mAP50, precision, recall, f1_score = evaluate_or_test_model(model, val_loader, device, epoch)
        # print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {total_train_loss/len(train_loader):.4f}, '
        #       f'Val Loss: {val_loss:.4f}, mAP: {mean_ap:.4f}')
        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {total_train_loss/len(train_loader):.4f}, mAP50: {mAP50:.4f}')
        print(f'F1 per group: {f1_score}, Recall per group: {recall}, Precision per group: {precision}')
        print(f'General F1: {sum(f1_score) / len(f1_score)}, General Recall: {sum(recall) / len(recall)}, General Precision: {sum(precision) / len(precision)}')
        lr_scheduler.step()

        
        if mAP50 >= best_mAP50:
            best_mAP50 = mAP50
            last_mAP50 = mAP50
            best_model_wts = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        elif mAP50 >= last_mAP50:
            last_mAP50 = mAP50
            epochs_no_improve -= 1
        else:
            last_mAP50 = mAP50
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stopping triggered")
                break
        clear_memory()

    model.load_state_dict(best_model_wts)
    return model, best_mAP50

# Initialize and train the model
model = get_model(len(train_dataset.classes))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
trained_model, best_mAP50 = train_model(model, train_loader, val_loader, optimizer, device, num_epochs=num_epochs)


                                                             

Epoch 1/30, Train Loss: 0.4199, mAP50: 0.3865
F1 per group: [0.17325309 0.12802071 0.12346902 0.08570451], Recall per group: [0.63466397 0.59855636 0.61901763 0.13477366], Precision per group: [0.10031949 0.07167553 0.06857342 0.06282974]
General F1: 0.12761183101361503, General Recall: 0.4967529057635237, General Precision: 0.07584954464881952


                                                             

Epoch 2/30, Train Loss: 0.3633, mAP50: 0.4569
F1 per group: [0.20842904 0.19999974 0.21292906 0.08395783], Recall per group: [0.6859525  0.66240977 0.64609572 0.40329218], Precision per group: [0.12288404 0.11778063 0.12746925 0.04685632]
General F1: 0.17632891879578327, General Recall: 0.5994375428007923, General Precision: 0.10374756120092359


                                                             

Epoch 3/30, Train Loss: 0.3478, mAP50: 0.4882
F1 per group: [0.16117423 0.17847894 0.21714498 0.09501495], Recall per group: [0.72536635 0.70960577 0.65239295 0.50102881], Precision per group: [0.09065934 0.10207668 0.13024893 0.0524841 ]
General F1: 0.16295327152375683, General Recall: 0.647098468348369, General Precision: 0.0938672632970964


                                                             

Epoch 4/30, Train Loss: 0.3391, mAP50: 0.4922
F1 per group: [0.28117905 0.27327926 0.30367071 0.19088462], Recall per group: [0.69151086 0.68684064 0.6511335  0.48045267], Precision per group: [0.1764668  0.17057363 0.19800843 0.11910227]
General F1: 0.2622534094499187, General Recall: 0.6274844207138857, General Precision: 0.16603778152286935


                                                             

Epoch 5/30, Train Loss: 0.3334, mAP50: 0.5074
F1 per group: [0.30814881 0.26743788 0.30989475 0.17326212], Recall per group: [0.6816574  0.70571904 0.68828715 0.53600823], Precision per group: [0.19907032 0.16497923 0.19996341 0.10333201]
General F1: 0.26468588984973795, General Recall: 0.6529179575649455, General Precision: 0.16683624232214916


                                                             

Epoch 6/30, Train Loss: 0.3270, mAP50: 0.5127
F1 per group: [0.34663868 0.32972057 0.3492881  0.21644803], Recall per group: [0.676857   0.6757357  0.68010076 0.51851852], Precision per group: [0.23297678 0.21806128 0.23498695 0.13677069]
General F1: 0.31052384688284923, General Recall: 0.6378029933874323, General Precision: 0.20569892423673625


                                                             

Epoch 7/30, Train Loss: 0.3221, mAP50: 0.4983
F1 per group: [0.41209363 0.37876346 0.47259906 0.30703441], Recall per group: [0.65083375 0.66352027 0.63539043 0.48045267], Precision per group: [0.30149813 0.2650255  0.37621178 0.22560386]
General F1: 0.39262264141570846, General Recall: 0.6075492806553666, General Precision: 0.292084819632578


Epoch 8:  34%|███▍      | 190/555 [12:07<23:51,  3.92s/it]

In [None]:
def save_model(model, path):
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")
save_model_path = f'./faster_rcnn_lr{lr}_bs{train_batch_size}_epochs{num_epochs}_mAP{best_mAP50}.pth'
# Example usage:
model_save_path = save_model_path
save_model(trained_model, model_save_path)


In [None]:
def visualize_comparison_three_plots(image_path, label_path, model, device, dataset_classes, transform=None):
    # Load and prepare image
    image = Image.open(image_path).convert("RGB")
    if transform:
        image_tensor, _ = transform(image, {})  # Assuming the transform function returns a tensor and a target
    else:
        image_tensor = F.to_tensor(image)  # Convert to tensor without any additional transformation

    image_tensor = image_tensor.unsqueeze(0).to(device)  # Add batch dimension and transfer to device

    # Load true labels and boxes
    tree = ET.parse(label_path)
    root = tree.getroot()
    true_boxes = []
    true_labels = []
    for obj in root.iter('object'):
        cls = obj.find('name').text
        xmlbox = obj.find('bndbox')
        xmin = int(float(xmlbox.find('xmin').text))
        ymin = int(float(xmlbox.find('ymin').text))
        xmax = int(float(xmlbox.find('xmax').text))
        ymax = int(float(xmlbox.find('ymax').text))
        true_boxes.append([xmin, ymin, xmax, ymax])
        true_labels.append(dataset_classes.index(cls))

    # Model inference
    model.eval()
    with torch.no_grad():
        predictions = model(image_tensor)
    predictions = apply_nms(predictions[0], 30)  # Assuming apply_nms is defined correctly

    # Prepare image for plotting
    image_np = np.array(image)

    # Set up three subplots
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 18))
    ax1.imshow(image_np)
    ax2.imshow(image_np)
    ax3.imshow(image_np)

    # Draw predicted boxes and labels on the second subplot
    pred_boxes = predictions['boxes'].cpu().numpy()
    pred_scores = predictions['scores'].cpu().numpy()
    pred_labels = predictions['labels'].cpu().numpy()

    for box, score, label in zip(pred_boxes, pred_scores, pred_labels):
        x1, y1, x2, y2 = box
        rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
        ax2.add_patch(rect)
        ax2.text(x1, y1, f'Pred: {dataset_classes[label]} {score:.2f}', color='white', fontsize=12,
                 bbox=dict(facecolor='red', alpha=0.5))

    # Draw true boxes and labels on the third subplot
    for box, label in zip(true_boxes, true_labels):
        x1, y1, x2, y2 = box
        rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='g', facecolor='none')
        ax3.add_patch(rect)
        ax3.text(x1, y2, f'True: {dataset_classes[label]}', color='white', fontsize=12,
                 bbox=dict(facecolor='green', alpha=0.5))

    ax1.set_title('Original Image')
    ax2.set_title('Predictions')
    ax3.set_title('Ground Truth')
    plt.show()

# Example usage
image_path = './Data/images/val/Norway_000830.jpg'
label_path = './Data/xmls/val/Norway_000830.xml'
visualize_comparison_three_plots(image_path, label_path, model, device, train_dataset.classes, transform=get_transform())


In [None]:
# Example usage
image_path = './Data/images/val/China_Drone_000086.jpg'
label_path = './Data/xmls/val/China_Drone_000086.xml'
visualize_comparison_three_plots(image_path, label_path, model, device, train_dataset.classes, transform=get_transform())


In [None]:
# Example usage
image_path = './Data/images/val/Japan_003127.jpg'
label_path = './Data/xmls/val/Japan_003127.xml'
visualize_comparison_three_plots(image_path, label_path, model, device, train_dataset.classes, transform=get_transform())


In [None]:
import os
import torch
import xml.etree.ElementTree as ET
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import functional as F
from PIL import Image
local_model_path = f'./faster_rcnn_lr{lr}_bs{train_batch_size}_epochs{num_epochs}_mAP{best_mAP50}.pth'
test_image_dir = os.path.join(image_dir, 'test')
test_label_dir = os.path.join(label_dir, 'test')
model_path = local_model_path
test_batch_size = 4

class TestDataset(Dataset):
    def __init__(self, image_dir, label_dir, transforms=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transforms = transforms
        self.images = list(sorted(os.listdir(image_dir)))
        self.labels = list(sorted(os.listdir(label_dir)))
        self.class_to_idx = {'Background': 0, 'D00': 1, 'D10': 2, 'D20': 3, 'D40': 4}  # Update as necessary

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        label_path = os.path.join(self.label_dir, self.labels[idx])
        img = Image.open(img_path).convert("RGB")

        tree = ET.parse(label_path)
        root = tree.getroot()
        boxes = []
        labels = []
        for obj in root.iter('object'):
            cls = obj.findtext('name')
            labels.append(self.class_to_idx[cls])
            xmlbox = obj.find('bndbox')
            boxes.append([int(xmlbox.findtext(tag)) for tag in ['xmin', 'ymin', 'xmax', 'ymax']])

        boxes = torch.as_tensor(boxes, dtype=torch.float32) if boxes else torch.zeros((0, 4), dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64) if labels else torch.zeros(0, dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) if boxes.nelement() != 0 else torch.zeros(0, dtype=torch.float32)

        target = {'boxes': boxes, 'labels': labels, 'image_id': torch.tensor([idx]), 'area': area, 'iscrowd': torch.zeros(len(boxes), dtype=torch.int64)}

        if self.transforms:
            img = self.transforms(img)

        return img, target

def get_transform():
    def transform(img):
        return F.to_tensor(img)
    return transform

# Load the model
def load_model(model_path, num_classes):
    model = get_model(num_classes)  # Ensure get_model is defined or imported
    model.load_state_dict(torch.load(model_path))
    model.to(device)
    model.eval()
    return model

test_dataset = TestDataset(test_image_dir, test_label_dir, transforms=get_transform())
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))

model = load_model(model_path, num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mean_ap, precision, recall, f1_score = evaluate_or_test_model(model, test_loader, device, 1)
print(f'Test mAP50: {mean_ap:.4f}')
print(f'F1 per group: {f1_score}, Recall per group: {recall}, Precision per group: {precision}')
print(f'General F1: {sum(f1_score) / len(f1_score)}, General Recall: {sum(recall) / len(recall)}, General Precision: {sum(precision) / len(precision)}')
