# Faster R-CNN Implementation

## Transforms / Model loading

In [1]:
import matplotlib.pyplot as plt, numpy as np, os, torch, random, cv2, json
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torchvision import models
from torchvision.transforms import v2 as transforms
import torchvision
from torchvision import models
from PIL import Image
from tqdm import tqdm
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

num_classes = 12
weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1
model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2( weights=weights)
# Replace the classifier with a new one for your number of classes
transforms1 = weights.transforms
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(transforms1)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

data_aug = transforms.Compose([
    transforms.ToImage(),
    transforms.RandomResizedCrop(224, scale=(0.9, 1.0), ratio=(0.95, 1.05)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomRotation(degrees=5),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

data_in = transforms.Compose([
    transforms.ToImage(),
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])



<class 'torchvision.transforms._presets.ObjectDetection'>


## DataLoader

In [2]:
def convert_bbox_format(bbox):
    """
    Convert [x, y, width, height] to [x_min, y_min, x_max, y_max]
    
    Args:
        bbox (list): [x_center, y_center, width, height]
    
    Returns:
        list: [x_min, y_min, x_max, y_max]
    """
    x, y, w, h = bbox
    x_min = x
    y_min = y
    x_max = x + w
    y_max = y + h
    return [x_min, y_min, x_max, y_max]
class ChessDataset(Dataset):
    def __init__(self, root_dir, partition, transform=None):
        self.anns = json.load(open(os.path.join(root_dir, 'annotations.json')))
        self.categories = [c['name'] for c in self.anns['categories']]
        self.root = root_dir
        self.ids = []
        self.file_names = []
        for x in self.anns['images']:
            self.file_names.append(x['path'])
            self.ids.append(x['id'])
        self.file_names = np.asarray(self.file_names)
        self.ids = np.asarray(self.ids)
        # create a list of size num_images, each element is a list of pieces
        self.boardLabels = [[] for _ in range(len(self.ids))]
        self.boardBB = [[] for _ in range(len(self.ids))]
        for piece in self.anns['annotations']['pieces']:
            idx = np.where(self.ids == piece['image_id'])[0][0]
            if "bbox" in piece.keys():
                bbox = convert_bbox_format(piece['bbox'])
                self.boardBB[idx].append(bbox)
                self.boardLabels[idx].append(piece['category_id'])
        if partition == 'train':
            self.split_ids = np.asarray(self.anns['splits']["chessred2k"]['train']['image_ids']).astype(int)
        elif partition == 'valid':
            self.split_ids = np.asarray(self.anns['splits']["chessred2k"]['val']['image_ids']).astype(int)
        else:
            self.split_ids = np.asarray(self.anns['splits']["chessred2k"]['test']['image_ids']).astype(int)
        self.split_ids = self.split_ids#[:5]
        intersect = np.isin(self.ids, self.split_ids)
        self.split_ids = np.where(intersect)[0]
        self.file_names = [self.file_names[i] for i in self.split_ids]
        self.boardBB = [self.boardBB[i] for i in self.split_ids]
        self.boardLabels = [self.boardLabels[i] for i in self.split_ids]
        #self.num_pieces = F.one_hot(self.num_pieces.long()-1, 32)
        self.ids = self.ids[self.split_ids]

        self.transform = transform
        print(f"Number of {partition} images: {len(self.file_names)}")
        self.images = {}
        counter = 0
        for i in range(len(self.file_names)):
            #image = cv2.imread(os.path.join(self.root, self.file_names[i]))
            #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.open(os.path.join(self.root, self.file_names[i]))
            if counter%100==0:
                print("image count", counter)
            counter += 1
            if self.transform:
                image = self.transform(image)
            self.images[self.file_names[i]] = image

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, i):
        image = self.images[self.file_names[i]]
        boxes = self.boardBB[i]
        labels = self.boardLabels[i]
        boxes_tensor = torch.as_tensor(boxes, dtype=torch.float32)
        area = (boxes_tensor[:, 3] - boxes_tensor[:, 1]) * (boxes_tensor[:, 2] - boxes_tensor[:, 0])
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
        
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"]= torch.tensor([i])  # Required for evaluation
        target["area"] = area
        target["iscrowd"] = iscrowd
        # format the target to handle the Dataloader
        #target =
        return image,target

train_dataset = ChessDataset('..', 'train', data_aug)
valid_dataset = ChessDataset('..', 'valid', data_in)
test_dataset = ChessDataset('..', 'test', data_in)

batchsize = 3
def collate_fn(batch):
    """
    Custom collate function for object detection batches.
    Handles variable numbers of bounding boxes per image.
    """
    images = {}
    targets = []
    
    for img, target in batch:
        images[target['image_id']] = img
        processed_target = {
            'boxes': torch.as_tensor(target['boxes'], dtype=torch.float32),
            'labels': torch.as_tensor(target['labels'], dtype=torch.int64),
            'image_id': torch.as_tensor(target['image_id'], dtype=torch.int64),
            'area': torch.as_tensor(target['area'], dtype=torch.float32),
            'iscrowd': torch.as_tensor(target['iscrowd'], dtype=torch.int64)
        }
        targets.append(processed_target)
    
    # Stack images (they should all be the same size after transforms)
    images = torch.stack(list(images.values()), dim=0)
    
    return images, targets
train_dataloader = DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=0,collate_fn=collate_fn)#,collate_fn=collate_fn
valid_dataloader = DataLoader(valid_dataset, batch_size=batchsize, shuffle=False, num_workers=0,collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=0,collate_fn=collate_fn )

Number of train images: 1442
image count 0
image count 100
image count 200
image count 300
image count 400
image count 500
image count 600
image count 700
image count 800
image count 900
image count 1000
image count 1100
image count 1200
image count 1300
image count 1400
Number of valid images: 330
image count 0
image count 100
image count 200
image count 300
Number of test images: 306
image count 0
image count 100
image count 200
image count 300


# train loop

In [6]:
def epoch_iter(dataloader, model, loss_fn, optimizer=None, is_train=True):
    if is_train:
      assert optimizer is not None, "When training, please provide an optimizer."
      
    num_batches = len(dataloader)

    #if is_train:
    # so that the output is always the loss function easier to know the loss
    model.train() # put model in train mode
    #else:
    #  model.eval()

    total_loss = 0.0
    preds = []
    labels = []
    metrics = {'loss': 0, 'class_loss': 0, 'box_loss': 0, 'rpn_loss': 0}
    with torch.set_grad_enabled(is_train):
      for batch,(images, targets) in enumerate(tqdm(dataloader)):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        if is_train:
          optimizer.zero_grad()
          losses = sum(loss for loss in loss_dict.values())
          # Backpropagation
          
          losses.backward()
          optimizer.step()
        else:
          losses = sum(loss for loss in loss_dict.values())
        
        # Accumulate metrics
        metrics['loss'] += losses.item()
        metrics['class_loss'] += loss_dict['loss_classifier'].item()
        metrics['box_loss'] += loss_dict['loss_box_reg'].item()
        metrics['rpn_loss'] += loss_dict['loss_objectness'].item() + loss_dict['loss_rpn_box_reg'].item()
    return metrics['loss'] / num_batches, metrics['class_loss']/ num_batches

In [7]:
def train(model, model_name, num_epochs, train_dataloader, validation_dataloader, loss_fn, optimizer):
  train_history = {'loss': [], 'accuracy': []}
  val_history = {'loss': [], 'accuracy': []}
  best_val_loss = np.inf
  print("Start training...")
  for t in range(num_epochs):
      print(f"\nEpoch {t+1}")
      train_loss, train_acc = epoch_iter(train_dataloader, model, loss_fn, optimizer)
      print(f"Train loss: {train_loss:.3f} \t Train metric: {train_acc:.3f}")
      val_loss, val_acc = epoch_iter(validation_dataloader, model, loss_fn, is_train=False,optimizer=optimizer)
      print(f"Val loss: {val_loss:.3f} \t Val metric: {val_acc:.3f}")

      # save model when val loss improves
      if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': t}
        torch.save(save_dict, model_name + '_best_model.pth')

      # save latest model
      save_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': t}
      torch.save(save_dict, model_name + '_latest_model.pth')

      # save training history for plotting purposes
      train_history["loss"].append(train_loss)
      train_history["accuracy"].append(train_acc)

      val_history["loss"].append(val_loss)
      val_history["accuracy"].append(val_acc)
      
  print("Finished")
  return train_history, val_history

In [8]:
# Define loss function
#loss_fn = nn.CrossEntropyLoss()
loss_fn = nn.MSELoss()
# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
num_epochs = 50

# TODO - Train the model
model.to(device)
train_history, val_history = train(model, 'chess_model', num_epochs, train_dataloader, valid_dataloader, loss_fn, optimizer)

Start training...

Epoch 1


100%|██████████| 481/481 [03:06<00:00,  2.58it/s]


Train loss: 102.215 	 Train metric: 2.146


100%|██████████| 110/110 [00:16<00:00,  6.55it/s]


Val loss: 97.473 	 Val metric: 2.134

Epoch 2


100%|██████████| 481/481 [03:03<00:00,  2.62it/s]


Train loss: 101.893 	 Train metric: 2.145


100%|██████████| 110/110 [00:16<00:00,  6.50it/s]


Val loss: 91.860 	 Val metric: 2.131

Epoch 3


100%|██████████| 481/481 [03:03<00:00,  2.62it/s]


Train loss: 101.216 	 Train metric: 2.145


100%|██████████| 110/110 [00:16<00:00,  6.52it/s]


Val loss: 96.912 	 Val metric: 2.134

Epoch 4


100%|██████████| 481/481 [02:55<00:00,  2.75it/s]


Train loss: 101.771 	 Train metric: 2.146


100%|██████████| 110/110 [00:15<00:00,  6.92it/s]


Val loss: 91.624 	 Val metric: 2.133

Epoch 5


100%|██████████| 481/481 [02:53<00:00,  2.78it/s]


Train loss: 101.468 	 Train metric: 2.143


100%|██████████| 110/110 [00:15<00:00,  6.91it/s]


Val loss: 92.416 	 Val metric: 2.134

Epoch 6


100%|██████████| 481/481 [02:52<00:00,  2.78it/s]


Train loss: 101.391 	 Train metric: 2.144


100%|██████████| 110/110 [00:15<00:00,  6.93it/s]


Val loss: 91.554 	 Val metric: 2.139

Epoch 7


100%|██████████| 481/481 [02:53<00:00,  2.77it/s]


Train loss: 101.810 	 Train metric: 2.144


100%|██████████| 110/110 [00:15<00:00,  6.88it/s]


Val loss: 92.060 	 Val metric: 2.141

Epoch 8


100%|██████████| 481/481 [02:53<00:00,  2.77it/s]


Train loss: 101.107 	 Train metric: 2.144


100%|██████████| 110/110 [00:16<00:00,  6.82it/s]


Val loss: 90.607 	 Val metric: 2.134

Epoch 9


100%|██████████| 481/481 [02:53<00:00,  2.77it/s]


Train loss: 100.290 	 Train metric: 2.142


100%|██████████| 110/110 [00:16<00:00,  6.83it/s]


Val loss: 91.687 	 Val metric: 2.076

Epoch 10


100%|██████████| 481/481 [02:53<00:00,  2.78it/s]


Train loss: 99.476 	 Train metric: 2.142


100%|██████████| 110/110 [00:16<00:00,  6.87it/s]


Val loss: 89.535 	 Val metric: 2.138

Epoch 11


100%|██████████| 481/481 [02:53<00:00,  2.78it/s]


Train loss: 98.327 	 Train metric: 2.140


100%|██████████| 110/110 [00:16<00:00,  6.85it/s]


Val loss: 92.502 	 Val metric: 2.136

Epoch 12


100%|██████████| 481/481 [02:53<00:00,  2.77it/s]


Train loss: 97.694 	 Train metric: 2.142


100%|██████████| 110/110 [00:16<00:00,  6.85it/s]


Val loss: 88.839 	 Val metric: 2.132

Epoch 13


100%|██████████| 481/481 [02:53<00:00,  2.77it/s]


Train loss: 97.150 	 Train metric: 2.141


100%|██████████| 110/110 [00:16<00:00,  6.85it/s]


Val loss: 90.022 	 Val metric: 2.134

Epoch 14


100%|██████████| 481/481 [02:53<00:00,  2.78it/s]


Train loss: 95.906 	 Train metric: 2.142


100%|██████████| 110/110 [00:16<00:00,  6.85it/s]


Val loss: 88.772 	 Val metric: 2.137

Epoch 15


100%|██████████| 481/481 [03:03<00:00,  2.62it/s]


Train loss: 93.906 	 Train metric: 2.139


100%|██████████| 110/110 [00:17<00:00,  6.26it/s]


Val loss: 88.283 	 Val metric: 2.134

Epoch 16


100%|██████████| 481/481 [03:03<00:00,  2.63it/s]


Train loss: 91.291 	 Train metric: 2.143


100%|██████████| 110/110 [00:17<00:00,  6.35it/s]


Val loss: 88.789 	 Val metric: 2.137

Epoch 17


100%|██████████| 481/481 [03:03<00:00,  2.63it/s]


Train loss: 89.832 	 Train metric: 2.143


100%|██████████| 110/110 [00:17<00:00,  6.28it/s]


Val loss: 90.629 	 Val metric: 2.138

Epoch 18


100%|██████████| 481/481 [03:02<00:00,  2.64it/s]


Train loss: 88.190 	 Train metric: 2.144


100%|██████████| 110/110 [00:17<00:00,  6.29it/s]


Val loss: 84.468 	 Val metric: 2.134

Epoch 19


100%|██████████| 481/481 [03:03<00:00,  2.63it/s]


Train loss: 87.911 	 Train metric: 2.143


100%|██████████| 110/110 [00:17<00:00,  6.26it/s]


Val loss: 84.882 	 Val metric: 2.136

Epoch 20


100%|██████████| 481/481 [03:03<00:00,  2.63it/s]


Train loss: 88.062 	 Train metric: 2.143


100%|██████████| 110/110 [00:17<00:00,  6.27it/s]


Val loss: 85.267 	 Val metric: 2.131

Epoch 21


100%|██████████| 481/481 [03:03<00:00,  2.62it/s]


Train loss: 87.227 	 Train metric: 2.145


100%|██████████| 110/110 [00:17<00:00,  6.26it/s]


Val loss: 85.656 	 Val metric: 2.134

Epoch 22


100%|██████████| 481/481 [03:04<00:00,  2.61it/s]


Train loss: 86.606 	 Train metric: 2.143


100%|██████████| 110/110 [00:17<00:00,  6.27it/s]


Val loss: 86.431 	 Val metric: 2.131

Epoch 23


100%|██████████| 481/481 [03:02<00:00,  2.63it/s]


Train loss: 85.775 	 Train metric: 2.144


100%|██████████| 110/110 [00:17<00:00,  6.30it/s]


Val loss: 85.772 	 Val metric: 2.132

Epoch 24


100%|██████████| 481/481 [03:04<00:00,  2.60it/s]


Train loss: 85.062 	 Train metric: 2.143


100%|██████████| 110/110 [00:17<00:00,  6.29it/s]


Val loss: 86.357 	 Val metric: 2.136

Epoch 25


100%|██████████| 481/481 [03:03<00:00,  2.63it/s]


Train loss: 84.915 	 Train metric: 2.143


100%|██████████| 110/110 [00:17<00:00,  6.36it/s]


Val loss: 86.293 	 Val metric: 2.135

Epoch 26


100%|██████████| 481/481 [03:02<00:00,  2.63it/s]


Train loss: 84.546 	 Train metric: 2.144


100%|██████████| 110/110 [00:17<00:00,  6.35it/s]


Val loss: 87.573 	 Val metric: 2.132

Epoch 27


100%|██████████| 481/481 [03:03<00:00,  2.63it/s]


Train loss: 83.909 	 Train metric: 2.143


100%|██████████| 110/110 [00:17<00:00,  6.36it/s]


Val loss: 86.518 	 Val metric: 2.131

Epoch 28


100%|██████████| 481/481 [03:02<00:00,  2.64it/s]


Train loss: 82.748 	 Train metric: 2.144


100%|██████████| 110/110 [00:17<00:00,  6.27it/s]


Val loss: 89.460 	 Val metric: 2.133

Epoch 29


100%|██████████| 481/481 [03:03<00:00,  2.63it/s]


Train loss: 83.174 	 Train metric: 2.143


100%|██████████| 110/110 [00:17<00:00,  6.37it/s]


Val loss: 87.821 	 Val metric: 2.133

Epoch 30


100%|██████████| 481/481 [03:03<00:00,  2.63it/s]


Train loss: 82.171 	 Train metric: 2.143


100%|██████████| 110/110 [00:17<00:00,  6.36it/s]


Val loss: 85.769 	 Val metric: 2.135

Epoch 31


100%|██████████| 481/481 [02:57<00:00,  2.71it/s]


Train loss: 81.090 	 Train metric: 2.143


100%|██████████| 110/110 [00:15<00:00,  7.04it/s]


Val loss: 85.548 	 Val metric: 2.131

Epoch 32


100%|██████████| 481/481 [02:51<00:00,  2.80it/s]


Train loss: 81.700 	 Train metric: 2.143


100%|██████████| 110/110 [00:16<00:00,  6.51it/s]


Val loss: 85.702 	 Val metric: 2.133

Epoch 33


100%|██████████| 481/481 [03:01<00:00,  2.64it/s]


Train loss: 81.121 	 Train metric: 2.143


100%|██████████| 110/110 [00:16<00:00,  6.54it/s]


Val loss: 86.899 	 Val metric: 2.135

Epoch 34


100%|██████████| 481/481 [03:01<00:00,  2.65it/s]


Train loss: 80.802 	 Train metric: 2.143


100%|██████████| 110/110 [00:16<00:00,  6.60it/s]


Val loss: 85.104 	 Val metric: 2.131

Epoch 35


100%|██████████| 481/481 [03:01<00:00,  2.65it/s]


Train loss: 78.871 	 Train metric: 2.144


100%|██████████| 110/110 [00:16<00:00,  6.55it/s]


Val loss: 86.939 	 Val metric: 2.134

Epoch 36


100%|██████████| 481/481 [03:02<00:00,  2.64it/s]


Train loss: 79.564 	 Train metric: 2.142


100%|██████████| 110/110 [00:16<00:00,  6.54it/s]


Val loss: 87.836 	 Val metric: 2.137

Epoch 37


100%|██████████| 481/481 [02:59<00:00,  2.68it/s]


Train loss: 78.719 	 Train metric: 2.144


100%|██████████| 110/110 [00:15<00:00,  7.01it/s]


Val loss: 85.864 	 Val metric: 2.134

Epoch 38


100%|██████████| 481/481 [02:50<00:00,  2.83it/s]


Train loss: 77.791 	 Train metric: 2.144


100%|██████████| 110/110 [00:15<00:00,  6.98it/s]


Val loss: 85.936 	 Val metric: 2.136

Epoch 39


100%|██████████| 481/481 [02:50<00:00,  2.83it/s]


Train loss: 77.373 	 Train metric: 2.144


100%|██████████| 110/110 [00:15<00:00,  6.98it/s]


Val loss: 86.818 	 Val metric: 2.137

Epoch 40


100%|██████████| 481/481 [02:52<00:00,  2.79it/s]


Train loss: 76.604 	 Train metric: 2.143


100%|██████████| 110/110 [00:15<00:00,  6.88it/s]


Val loss: 87.497 	 Val metric: 2.131

Epoch 41


100%|██████████| 481/481 [02:58<00:00,  2.70it/s]


Train loss: 76.251 	 Train metric: 2.144


100%|██████████| 110/110 [00:15<00:00,  6.93it/s]


Val loss: 87.396 	 Val metric: 2.130

Epoch 42


100%|██████████| 481/481 [02:53<00:00,  2.78it/s]


Train loss: 75.253 	 Train metric: 2.144


100%|██████████| 110/110 [00:15<00:00,  6.92it/s]


Val loss: 86.799 	 Val metric: 2.135

Epoch 43


100%|██████████| 481/481 [02:56<00:00,  2.73it/s]


Train loss: 75.283 	 Train metric: 2.144


100%|██████████| 110/110 [00:16<00:00,  6.63it/s]


Val loss: 87.471 	 Val metric: 2.136

Epoch 44


100%|██████████| 481/481 [02:58<00:00,  2.69it/s]


Train loss: 74.502 	 Train metric: 2.144


100%|██████████| 110/110 [00:16<00:00,  6.83it/s]


Val loss: 87.057 	 Val metric: 2.132

Epoch 45


100%|██████████| 481/481 [03:00<00:00,  2.67it/s]


Train loss: 73.630 	 Train metric: 2.144


100%|██████████| 110/110 [00:16<00:00,  6.54it/s]


Val loss: 86.110 	 Val metric: 2.135

Epoch 46


100%|██████████| 481/481 [02:57<00:00,  2.70it/s]


Train loss: 71.963 	 Train metric: 2.143


100%|██████████| 110/110 [00:16<00:00,  6.86it/s]


Val loss: 87.027 	 Val metric: 2.135

Epoch 47


100%|██████████| 481/481 [02:52<00:00,  2.78it/s]


Train loss: 71.968 	 Train metric: 2.143


100%|██████████| 110/110 [00:16<00:00,  6.87it/s]


Val loss: 86.277 	 Val metric: 2.134

Epoch 48


100%|██████████| 481/481 [02:53<00:00,  2.78it/s]


Train loss: 72.162 	 Train metric: 2.143


100%|██████████| 110/110 [00:16<00:00,  6.87it/s]


Val loss: 88.045 	 Val metric: 2.138

Epoch 49


100%|██████████| 481/481 [02:53<00:00,  2.78it/s]


Train loss: 70.460 	 Train metric: 2.143


100%|██████████| 110/110 [00:15<00:00,  6.88it/s]


Val loss: 87.220 	 Val metric: 2.134

Epoch 50


100%|██████████| 481/481 [02:52<00:00,  2.78it/s]


Train loss: 69.527 	 Train metric: 2.144


100%|██████████| 110/110 [00:15<00:00,  6.89it/s]


Val loss: 87.755 	 Val metric: 2.135
Finished


In [9]:

import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision.ops import box_iou
def calculate_average_iou(dataloader, model, device, score_threshold=0.5):
    """
    Calculate average IoU between predictions and ground truth boxes
    
    Args:
        dataloader: PyTorch DataLoader yielding (images, targets)
        model: Your detection model
        device: torch.device
        score_threshold: Minimum confidence score to consider a prediction
        
    Returns:
        mean_iou: Average IoU across all matched predictions
        matched_ratio: Percentage of predictions that matched with ground truth
    """
    model.eval()
    total_iou = 0.0
    total_matches = 0
    total_predictions = 0
    
    with torch.no_grad():
        for images, targets in tqdm(dataloader, desc="Calculating IoU"):
            images = list(img.to(device) for img in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            predictions = model(images)
            
            for pred, target in zip(predictions, targets):
                # Filter predictions by score threshold
                keep = pred['scores'] >= score_threshold
                pred_boxes = pred['boxes'][keep]
                gt_boxes = target['boxes']
                
                if len(pred_boxes) == 0 or len(gt_boxes) == 0:
                    continue
                
                # Calculate IoU matrix
                iou_matrix = box_iou(pred_boxes, gt_boxes)
                
                # Find best matches (one prediction can only match one ground truth)
                pred_matches = iou_matrix.argmax(dim=1)  # For each pred, find best gt
                best_ious = iou_matrix[torch.arange(len(pred_matches)), pred_matches]
                
                # Filter matches where IoU > 0
                valid_matches = best_ious > 0
                total_iou += best_ious[valid_matches].sum().item()
                total_matches += valid_matches.sum().item()
                total_predictions += len(pred_boxes)
    
    # Calculate metrics
    mean_iou = total_iou / total_matches if total_matches > 0 else 0.0
    matched_ratio = total_matches / total_predictions if total_predictions > 0 else 0.0
    
    return mean_iou, matched_ratio
mean_iou, matched_ratio = calculate_average_iou(
    dataloader=test_dataloader,
    model=model,
    device=device,
    score_threshold=0.5  # Only consider predictions with confidence > 50%
)

print(f"Average IoU: {mean_iou:.4f}")
print(f"Matched Ratio: {matched_ratio:.2%} (percentage of predictions that matched ground truth)")

Calculating IoU: 100%|██████████| 102/102 [00:14<00:00,  6.99it/s]

Average IoU: 0.0000
Matched Ratio: 0.00% (percentage of predictions that matched ground truth)



