In [1]:
import os
import time

import torchvision
from torchvision.transforms import v2, PILToTensor
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

import torch
from torch.utils.data import Dataset, random_split
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import cv2
import numpy as np

import xml.etree.ElementTree as ET
from PIL import Image



In [2]:
def show_image(image):
    plt.imshow(image.permute(1,2,0))
    plt.axis('off')
    plt.show()

    
def show_image_with_box(image,boxes):

    image_draw = image.permute(1,2,0).clone().cpu().numpy()
    boxes_np = [box.tolist() for box in boxes]

    for box in boxes_np:
        x1, y1, x2, y2 = box
        cv2.rectangle(image_draw, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 1)
    
    plt.imshow(image_draw)
    plt.axis('off')
    plt.show()

def get_images_labeled(labels_path):
    return os.listdir(labels_path)

def change_image_to_label(image_path):
    new_path = image_path.replace('images','labels')
    if 'jpg' in new_path:
        new_path = new_path.replace('jpg','xml')
    elif 'png' in new_path:
        new_path = new_path.replace('png','xml')
    
    return new_path

def getBoxes(xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()

    boxes = []

    for elem in root:
        if elem.tag == 'object':
            for obj_elem in elem:
                if obj_elem.tag == 'bndbox':
                    line = []
                    for bbox_elem in obj_elem:
                        value = float(bbox_elem.text)
                        line.append(value)
                    boxes.append(line)
    return boxes

def getLabels(xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()

    labels = []

    for elem in root:
        if elem.tag == 'object':
            for obj_elem in elem:
                if obj_elem.tag == 'name':
                    labels.append(1)
    return labels

def plot(imgs, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0])
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            boxes = None
            masks = None
            if isinstance(img, tuple):
                img, target = img
                if isinstance(target, dict):
                    boxes = target.get("boxes")
                    masks = target.get("masks")
                elif isinstance(target, tv_tensors.BoundingBoxes):
                    boxes = target
                else:
                    raise ValueError(f"Unexpected target type: {type(target)}")
            img = F.to_image(img)
            if img.dtype.is_floating_point and img.min() < 0:
                # Poor man's re-normalization for the colors to be OK-ish. This
                # is useful for images coming out of Normalize()
                img -= img.min()
                img /= img.max()

            img = F.to_dtype(img, torch.uint8, scale=True)
            if boxes is not None:
                img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
            if masks is not None:
                img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)

            ax = axs[row_idx, col_idx]
            ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

In [3]:
class PoolDatasetV2(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        self.labeled_images = get_images_labeled(self.root_dir.replace('images','labels'))
        self.retrieve_data()

    def retrieve_data(self):
        self.images = []
        self.boxes = []
        self.labels = []
        
        for file in os.listdir(self.root_dir):
            
            full_image_path = self.root_dir + "/" + file
            label_name = change_image_to_label(file)


            # Deal with the image
            img = Image.open(full_image_path).convert('RGB')

            # Transform to tensor
            if "PIL" in str(type(img)):
                img = PILToTensor()(img)
            else:
                img = torch.Tensor(img)

            self.images.append(img)


            # Deal with the label
            if label_name in self.labeled_images:
                boxes = getBoxes(self.root_dir.replace('images','labels') + "/" + label_name)
                labels = getLabels(self.root_dir.replace('images','labels') + "/" + label_name)
                self.labels.append(labels)
                #self.boxes.append([torch.Tensor(box) for box in boxes])
                self.boxes.append(boxes)

            else:
                self.boxes.append([torch.Tensor([0.0, 0.0, 224.0, 224.0])])
                self.labels.append([0])

    def __len__(self):
        return len(self.images)
    

    def __getitem__(self, idx):
        target = {}
        img = self.images[idx]
        boxes = self.boxes[idx]
        labels = self.labels[idx]

        old_width = img.shape[1]

        if self.transform:
            img = torch.Tensor(self.transform(img))
            
            factor = old_width / img.shape[1]

            if len(boxes) > 0:
                #print(boxes)
                boxes = list(map(lambda l: [v / factor for v in l], boxes))
                #print(boxes)
                boxes = torch.Tensor(boxes)
                
                #max_boxes = max(len(boxes), 1)  # Ensure at least 1 box
                #pad_boxes = F.pad(torch.stack(boxes), (0, 0, 0, max_boxes - len(boxes)), value=-1)
                #boxes = pad_boxes.unbind(0)
                labels = torch.IntTensor(labels)
        else:
            boxes = torch.Tensor(boxes)
            labels = torch.IntTensor(labels, dtype=torch.int64)
        
        target['boxes'] = boxes
        target['labels'] = labels

        return img, target
    
    def split_Data(self, n_test=0.33):
        test_size = round(n_test * len(self.images))
        train_size = len(self.images) - test_size

        return random_split(self, [train_size, test_size])
    
    def collate(self,batch):
        images = list()
        boxes = list()

        for b in batch:
            images.append(b[0])
            boxes.append(b[1])
        
        images = torch.stack(images, dim=0)
        return images,boxes

---------------------------------------------------------

In [4]:
def show_image_with_box(image,boxes):

    image_draw = image.permute(1,2,0).clone().cpu().numpy()
    boxes_np = [box.tolist() for box in boxes]
    for box in boxes_np:
        x1, y1, x2, y2 = box
        cv2.rectangle(image_draw, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 1)
    
    plt.imshow(image_draw)
    plt.axis('off')
    plt.show()

In [5]:
transforms = v2.Compose([
    v2.Resize((224, 224)),
    v2.ToImage(),
    #v2.RandomHorizontalFlip(p=1),
    v2.ToDtype(torch.float32, scale=True),
    #v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

ROOT_DIR = "/kaggle/input/dataset/subset/images"

In [6]:
ds = PoolDatasetV2(ROOT_DIR,transforms)

In [7]:
train, test = ds.split_Data(0.33)
len(train)

34

In [8]:
train[1]



(tensor([[[0.3255, 0.2235, 0.3412,  ..., 0.0706, 0.0863, 0.0941],
          [0.3294, 0.2667, 0.4235,  ..., 0.0784, 0.0902, 0.0784],
          [0.3333, 0.3686, 0.3451,  ..., 0.0667, 0.0824, 0.0549],
          ...,
          [0.0627, 0.0627, 0.0588,  ..., 0.2784, 0.2431, 0.1882],
          [0.0627, 0.0627, 0.0667,  ..., 0.2784, 0.2745, 0.2471],
          [0.0471, 0.0549, 0.0706,  ..., 0.2588, 0.3059, 0.3137]],
 
         [[0.3294, 0.2275, 0.3490,  ..., 0.0980, 0.1059, 0.1137],
          [0.3333, 0.2745, 0.4314,  ..., 0.1059, 0.1176, 0.1059],
          [0.3412, 0.3765, 0.3529,  ..., 0.0941, 0.1098, 0.0902],
          ...,
          [0.1059, 0.1059, 0.1020,  ..., 0.2824, 0.2471, 0.1922],
          [0.1059, 0.1059, 0.1020,  ..., 0.2784, 0.2706, 0.2431],
          [0.0902, 0.0980, 0.1059,  ..., 0.2588, 0.2980, 0.3059]],
 
         [[0.2667, 0.1647, 0.2980,  ..., 0.1294, 0.1294, 0.1373],
          [0.2706, 0.2196, 0.3804,  ..., 0.1373, 0.1490, 0.1294],
          [0.2863, 0.3216, 0.3020,  ...,

In [9]:
# def collate_fn(batch):
#     images = [item[0] for item in batch]
#     boxes = [item[1] for item in batch]

#     max_size_labels = max([len(box) for box in boxes])

#     for box in boxes:
#         if len(box) < max_size_labels:
#             toadd = torch.Tensor([0.0, 0.0, 0.0, 0.0])
#             box.extend([toadd for i in range(max_size_labels - len(box))])

#     return images, boxes

def collate_fn(batch):
    return batch

In [10]:
train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collate_fn, pin_memory=True if torch.cuda.is_available() else False)
test_loader = DataLoader(test, batch_size=16, shuffle=True, collate_fn=collate_fn, pin_memory=True if torch.cuda.is_available() else False)

# Training

In [50]:
model = fasterrcnn_resnet50_fpn(pretrained=False)
num_classes = 2
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)



In [51]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [52]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
num_epochs = 10

In [53]:
model.to(device)
for epoch in range(num_epochs):
    epoch_loss = 0
    start = time.time()
    for data in train_loader:
        # print(data)
        imgs = []
        targets = []
        for d in data:
            imgs.append(d[0].to(device))
            targ = {}
            targ['boxes'] = d[1]['boxes'].to(device)
            targ['labels'] = d[1]['labels'].to(torch.int64).to(device)
            targets.append(targ)

        loss_dict = model(imgs, targets)
        loss = sum(loss for loss in loss_dict.values())
        epoch_loss += loss.cpu().detach().numpy()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    total_time = time.time() - start
    print(f'Epoch: {epoch}, Loss: {epoch_loss} -> {round(total_time)}s')

Epoch: 0, Loss: 2.8054462671279907 -> 6s
Epoch: 1, Loss: 1.719754934310913 -> 6s
Epoch: 2, Loss: 1.315606951713562 -> 6s
Epoch: 3, Loss: 1.0491372048854828 -> 6s
Epoch: 4, Loss: 1.1060037016868591 -> 6s
Epoch: 5, Loss: 0.7835223078727722 -> 6s
Epoch: 6, Loss: 1.0069560110569 -> 6s
Epoch: 7, Loss: 0.7276667952537537 -> 6s
Epoch: 8, Loss: 0.7058354914188385 -> 6s
Epoch: 9, Loss: 0.7192432731389999 -> 6s


-----------------

# Evaluate

In [85]:
import numpy as np

def compute_iou(box1, box2):
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    
    intersection = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
    area_box1 = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
    area_box2 = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
    
    union = area_box1 + area_box2 - intersection
    
    iou = intersection / union
    return iou

def evaluate_detections(pred_boxes, gt_boxes, iou_threshold=0.5):
    TP, FP, FN = 0, 0, 0
    matched_gt_boxes = set()
    
    for pred_box in pred_boxes:
        matched = False
        for gt_box in gt_boxes:
            if compute_iou(pred_box, gt_box) >= iou_threshold:
                if gt_box not in matched_gt_boxes:
                    TP += 1
                    matched_gt_boxes.add(gt_box)
                    matched = True
                    break
        if not matched:
            FP += 1
    
    FN = len(gt_boxes) - len(matched_gt_boxes)
    return TP, FP, FN

def calculate_metrics(TP, FP, FN):
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return precision, recall, f1_score

def calculate_accuracy(TP, FP, FN):
    total_predictions = TP + FP
    total_ground_truths = TP + FN
    accuracy = TP / (total_predictions + total_ground_truths - TP) if (total_predictions + total_ground_truths - TP) > 0 else 0
    return accuracy

def evaluate_model(model, dataloader, device, iou_threshold=0.5, min_confidence = 0.7):
    model.eval()
    TP, FP, FN = 0, 0, 0

    with torch.no_grad():
        for data in dataloader:
            imgs = []
            targets = []
            for d in data:
                imgs.append(d[0].to(device))
                targ = {}
                targ['boxes'] = d[1]['boxes']
                targ['labels'] = d[1]['labels']
                targets.append(targ)

            outputs = model(imgs)
                       
            
            for i in range(len(imgs)):
                
                scores = outputs[i]['scores']
                high_score_indices = scores >= min_confidence
                
                pred_boxes = outputs[i]['boxes'][high_score_indices]
                gt_boxes = targets[i]['boxes']

                tp, fp, fn = evaluate_detections(pred_boxes, gt_boxes, iou_threshold)
                TP += tp
                FP += fp
                FN += fn

    precision, recall, f1_score = calculate_metrics(TP, FP, FN)
    accuracy = calculate_accuracy(TP, FP, FN)
    
    return precision, recall, f1_score, accuracy

    

In [86]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

precision, recall, f1_score, accuracy = evaluate_model(model, train_loader, device)

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1_score:.4f}")
print(f"Accuracy: {accuracy:.4f}")

Precision: 0.6790
Recall: 0.9649
F1 Score: 0.7971
Accuracy: 0.6627
