In [1]:
!pip install pycocotools





# Libraries

In [2]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torchvision.transforms import transforms
from torchvision.ops import box_iou
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.strategies import DDPStrategy
from pycocotools.coco import COCO
from PIL import Image
import numpy as np
from pytorch_lightning.callbacks import RichProgressBar

# Path and Directory

In [3]:
root_dir = '/kaggle/working/'
train_dir = '/kaggle/input/mycqadataset/train'
test_dir = '/kaggle/input/mycqadataset/test'
val_dir = '/kaggle/input/mycqadataset/valid'
train_labels_file = '/kaggle/input/mycqadataset/train/_annotations.coco.json'
val_labels_file = '/kaggle/input/mycqadataset/valid/_annotations.coco.json'
LR = 0.0001
WEIGHT_DECAY = 0.0001
NUM_EPOCHS = 30
BATCH_SIZE = 16
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Custom Dataset Classes

In [4]:
class YoloDataset(torch.utils.data.Dataset):
    def __init__(self, images_dir, labels_dir, transform=None):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(images_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load image
        image_file = self.image_files[idx]
        img_path = os.path.join(self.images_dir, image_file)
        img = Image.open(img_path).convert("RGB")

        # Load corresponding label file
        label_file = os.path.splitext(image_file)[0] + ".txt"
        label_path = os.path.join(self.labels_dir, label_file)

        boxes = []
        labels = []

        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    class_id = int(parts[0])
                    x_center, y_center, width, height = map(float, parts[1:])
                    labels.append(class_id)

                    # Convert YOLO format back to bounding box coordinates
                    xmin = x_center - width / 2
                    ymin = y_center - height / 2
                    xmax = x_center + width / 2
                    ymax = y_center + height / 2

                    boxes.append([xmin, ymin, xmax, ymax])

        if not boxes:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
        else:
            boxes = torch.tensor(boxes, dtype=torch.float32)
            labels = torch.tensor(labels, dtype=torch.int64)
        if self.transform:
            img = self.transform(img)
        target = {"boxes": boxes, "labels": labels}

        return img, target

In [5]:
class CocoDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, annFile, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.coco = COCO(annFile)
        self.ids = self.coco.getImgIds()

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        img_info = self.coco.loadImgs(img_id)[0]
        path = os.path.join(self.root_dir, img_info['file_name'])
        img = Image.open(path).convert("RGB")

        boxes = []
        labels = []
        for ann in anns:
            x, y, w, h = ann['bbox']
            if w > 0 and h > 0:
                xmin, ymin, xmax, ymax = x, y, x + w, y + h
                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(ann['category_id'])

        if not boxes:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
        else:
            boxes = torch.tensor(boxes, dtype=torch.float32)
            labels = torch.tensor(labels, dtype=torch.int64)

        if self.transform:
            img = self.transform(img)

        target = {"boxes": boxes, "labels": labels}
        return img, target

# Preparing Data

In [6]:
# Transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Datasets and Dataloaders
# train_dataset = CocoDataset(root_dir=train_dir, annFile=train_labels_file, transform=transform)
# val_dataset = CocoDataset(root_dir=val_dir, annFile=val_labels_file, transform=transform)
train_dataset = YoloDataset(
    images_dir="/kaggle/input/chart-detection-v4/doclaynet_yolo_dataset_v4/images/train",
    labels_dir="/kaggle/input/chart-detection-v4/doclaynet_yolo_dataset_v4/labels/train",
    transform=transform
)

val_dataset = YoloDataset(
    images_dir="/kaggle/input/chart-detection-v4/doclaynet_yolo_dataset_v4/images/val",
    labels_dir="/kaggle/input/chart-detection-v4/doclaynet_yolo_dataset_v4/labels/val",
    transform=transform
)
def custom_collate_fn(batch):
    images, targets = zip(*batch)
    return list(images), list(targets)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=3, shuffle=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=3, shuffle=False, collate_fn=custom_collate_fn)

# FasterRCNN Lightning

In [7]:
class FasterRCNNLightning(LightningModule):
    def __init__(self):
        super(FasterRCNNLightning, self).__init__()
        self.model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
        num_classes = 2  
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features

        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        self.lr = LR
        self.weight_decay = WEIGHT_DECAY

    def forward(self, images):
        return self.model(images)

    def training_step(self, batch, batch_idx):
        images, targets = batch
        images = list(img.to(self.device) for img in images)
        targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
        loss_dict = self.model(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        self.log('train_loss', loss, prog_bar=True, batch_size=len(images))
        return loss

    def validation_step(self, batch, batch_idx):
        images, targets = batch
        images = list(img.to(self.device) for img in images)
        targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

        predictions = self.model(images)
        TP, FP, FN = 0, 0, 0
        iou_threshold = 0.5
        iou_scores = []

        for target, prediction in zip(targets, predictions):
            target_boxes = target['boxes']
            target_labels = target['labels']
            predicted_boxes = prediction['boxes']
            predicted_labels = prediction['labels']

            if target_boxes.shape[0] == 0 or predicted_boxes.shape[0] == 0:
                FN += len(target_boxes)
                FP += len(predicted_boxes)
                continue

            iou = box_iou(predicted_boxes, target_boxes)
            matched_gt = set()

            for i in range(iou.size(0)):
                max_iou, idx = iou[i].max(0)
                if max_iou > iou_threshold and idx.item() not in matched_gt:
                    TP += 1
                    matched_gt.add(idx.item())
                else:
                    FP += 1

            FN += len(target_boxes) - len(matched_gt)
            iou_scores.append(iou.mean().item())

        precision = TP / (TP + FP) if (TP + FP) > 0 else 0
        recall = TP / (TP + FN) if (TP + FN) > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        mean_iou = np.mean(iou_scores) if iou_scores else 0

        self.log('val_precision', precision, prog_bar=True, sync_dist=True)
        self.log('val_recall', recall, prog_bar=True, sync_dist=True)
        self.log('val_f1_score', f1_score, prog_bar=True, sync_dist=True)
        self.log('val_mean_iou', mean_iou, prog_bar=True, sync_dist=True)

        map_scores = {}
        for threshold in [0.5, 0.75, 0.9]:
            TP, FP, FN = 0, 0, 0
            for target, prediction in zip(targets, predictions):
                target_boxes = target['boxes']
                predicted_boxes = prediction['boxes']

                if target_boxes.shape[0] == 0 or predicted_boxes.shape[0] == 0:
                    FN += len(target_boxes)
                    FP += len(predicted_boxes)
                    continue

                iou = box_iou(predicted_boxes, target_boxes)
                matched_gt = set()

                for i in range(iou.size(0)):
                    max_iou, idx = iou[i].max(0)
                    if max_iou > threshold and idx.item() not in matched_gt:
                        TP += 1
                        matched_gt.add(idx.item())
                    else:
                        FP += 1

                FN += len(target_boxes) - len(matched_gt)

            precision = TP / (TP + FP) if (TP + FP) > 0 else 0
            recall = TP / (TP + FN) if (TP + FN) > 0 else 0
            map_scores[f'mAP@{threshold:.2f}'] = precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        for key, value in map_scores.items():
            self.log(key, value, prog_bar=True, sync_dist=True)

        return predictions

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return optimizer


# Callbacks

In [8]:
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    mode='min',
    filename='retinanet-{epoch:02d}-{train_loss:.2f}',
    save_top_k=3
)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
# rich_progress_bar = RichProgressBar()

# Trainer

In [9]:
trainer = Trainer(
    max_epochs=NUM_EPOCHS,
    accelerator='gpu',
    devices=torch.cuda.device_count(),
    strategy="ddp_notebook",
    callbacks=[checkpoint_callback, lr_monitor],
    precision='16-mixed'
)

# Train

In [10]:
# Train
model = FasterRCNNLightning()
trainer.fit(model, train_loader, val_loader)

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth


  0%|          | 0.00/160M [00:00<?, ?B/s]

  9%|▉         | 14.4M/160M [00:00<00:01, 150MB/s]

 24%|██▎       | 37.8M/160M [00:00<00:00, 206MB/s]

 38%|███▊      | 61.4M/160M [00:00<00:00, 225MB/s]

 53%|█████▎    | 84.9M/160M [00:00<00:00, 233MB/s]

 68%|██████▊   | 108M/160M [00:00<00:00, 237MB/s] 

 82%|████████▏ | 131M/160M [00:00<00:00, 239MB/s]

 97%|█████████▋| 155M/160M [00:00<00:00, 241MB/s]

100%|██████████| 160M/160M [00:00<00:00, 232MB/s]




  self.pid = os.fork()


/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 3. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]