In [1]:
import os
import pprint
from dataclasses import dataclass

import numpy as np
import pytorch_lightning as pl
import scipy.io as io
import torch
import torch.utils.data as td
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
@dataclass
class BBox:
    x: int
    y: int
    w: int
    h: int

    @staticmethod
    def parse_from_mat(label):
        xs, ys = label[1][0], label[3][0]

        x, y = min(xs), min(ys)
        w, h = max(xs) - x, max(ys) - y

        return BBox(x, y, w, h)

    def to_tensor(self):
        return torch.Tensor([self.x, self.y, self.x + self.w, self.y + self.h])
    
    def is_point(self) -> bool:
        return self.w == 0 or self.h == 0


def parse_boxes(annotation):
    boxes = []
 
    for label in annotation:
        box = BBox.parse_from_mat(label)

        if box.is_point():
            continue

        boxes.append(box.to_tensor())

    return torch.stack(boxes)

In [3]:
class Subset(td.Dataset):
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices
    
    def __getitem__(self, index):
        return self.dataset[self.indices[index]]

    def __len__(self):
        return len(self.indices)
    
class DetectionDataset(td.Dataset):
    def __init__(
        self, 
        images_root, 
        labels_root = None, 
        resize=(300, 300)
    ):
        self.images_root = images_root
        self.labels_root = labels_root
        self.is_testing  = labels_root is None
        self.resize      = resize
    
        self.files = [f.split(".")[0] for f in os.listdir(images_root)]
        self.files = [f for f in self.files if len(f) != 0]
        
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.resize),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        
    def __getitem__(self, index):
        name = self.files[index]
        
        if os.path.exists(f"{self.images_root}/{name}.jpg"):
            image = Image.open(f"{self.images_root}/{name}.jpg")
        else:
            image = Image.open(f"{self.images_root}/{name}.JPG")
            
        w, h  = image.width, image.height
        image = self.transforms(image)
        
        if self.is_testing:
            return image
        
        annotation = io.loadmat(f"{self.labels_root}/gt_{name}.mat")["gt"]
        old_dims   = torch.FloatTensor([w, h, w, h])
        boxes      = parse_boxes(annotation)
        boxes      = boxes / old_dims
        boxes      = boxes * torch.FloatTensor([self.resize[1], self.resize[0], self.resize[1], self.resize[0]])
        labels     = torch.LongTensor([1] * len(boxes))
            
        return image, boxes, labels
    
    def __len__(self):
        return len(self.files)

In [4]:
def collate_fn(batch):
    images, bboxes, labels = [], [], []
    
    for image, box, label in batch:
        images.append(image)
        bboxes.append(box)
        labels.append(label)
    
    return torch.stack(images, dim=0), bboxes, labels
    

class DataModule(pl.LightningDataModule):
    def __init__(
            self,
            dataset: DetectionDataset,
            collate_fn = collate_fn,
            batch_size: int   = 32,
            valid_size: float = 0.2
    ):
        super().__init__()
        
        self.dataset    = dataset
        self.batch_sz   = batch_size
        self.collate_fn = collate_fn
        
        train, valid = self._train_valid_split(self.dataset, valid_size=valid_size)

        self.train = train
        self.valid = valid

    def _train_valid_split(self, dataset, valid_size: float):
        train_indices, valid_indices = train_test_split(np.arange(len(dataset)), test_size=valid_size)
        train = Subset(dataset, train_indices)
        valid = Subset(dataset, valid_indices)
        return train, valid

    def train_dataloader(self):
        return td.DataLoader(self.train, batch_size=self.batch_sz, shuffle=True, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return td.DataLoader(self.valid, batch_size=self.batch_sz, collate_fn=self.collate_fn)

In [5]:
class Detector(pl.LightningModule):
    def __init__(self, torch_model):
        super().__init__()
        
        self.model = torch_model
        self.valid_map = MeanAveragePrecision(iou_thresholds=[0.5, 0.75, 0.9])
        
    def training_step(self, batch, _):
        images, bboxes, labels = batch
         
        targets = [
            {"boxes": boxes, "labels": label} for boxes, label in zip(bboxes, labels)
        ]
        
        output = self.model(images, targets)
        loss   = output["loss_classifier"] + output["loss_box_reg"] + output["loss_objectness"] + output["loss_rpn_box_reg"]
        loss   = loss / 4
        
        return loss

    def validation_step(self, batch, *args, **kwargs):
        images, bboxes, labels = batch
        
        targets = [
            {"boxes": boxes, "labels": label} for boxes, label in zip(bboxes, labels)
        ]
        
        outputs = self.model(images)

        self.valid_map.update(outputs, targets)

    def validation_epoch_end(self, outputs):
        valid_map = self.valid_map.compute()
        
        self.log("map", valid_map["map"].item())
        
        pprint.pprint(valid_map)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=5e-5)

In [6]:
train = DetectionDataset(
    images_root="../input/text-recognition-total-text-dataset/totaltext/Images/Train",
    labels_root="../input/text-recognition-total-text-dataset/TT_new_train_GT/Train"
)

data_module = DataModule(train, valid_size=0.1)
model       = fasterrcnn_mobilenet_v3_large_fpn(pretrained=True, num_classes=2)
detector    = Detector(model)

callbacks = [
    pl.callbacks.EarlyStopping(
        monitor="map",
        patience=2,
        mode="max"
    ),
    pl.callbacks.ModelCheckpoint(
        dirpath="checkpoints",
        monitor="map",
        mode="max",
        filename="fasterrcnn_mobilenet_v3_large_fpn-{epoch}-{map:.4f}"
    )
]

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=10,
    callbacks=callbacks
)

Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth


  0%|          | 0.00/21.1M [00:00<?, ?B/s]

In [7]:
trainer.fit(detector, data_module)

Sanity Checking: 0it [00:00, ?it/s]

{'map': tensor(0.0001),
 'map_50': tensor(0.0003),
 'map_75': tensor(0.),
 'map_large': tensor(0.0188),
 'map_medium': tensor(0.0001),
 'map_per_class': tensor(-1.),
 'map_small': tensor(7.2535e-06),
 'mar_1': tensor(0.),
 'mar_10': tensor(0.0013),
 'mar_100': tensor(0.0108),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.0667),
 'mar_medium': tensor(0.0207),
 'mar_small': tensor(0.0021)}




Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

{'map': tensor(0.0003),
 'map_50': tensor(0.0010),
 'map_75': tensor(4.6358e-06),
 'map_large': tensor(0.0091),
 'map_medium': tensor(0.0006),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.0001),
 'mar_1': tensor(0.0004),
 'mar_10': tensor(0.0042),
 'mar_100': tensor(0.0333),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.1091),
 'mar_medium': tensor(0.0561),
 'mar_small': tensor(0.0197)}


Validation: 0it [00:00, ?it/s]

{'map': tensor(0.0009),
 'map_50': tensor(0.0028),
 'map_75': tensor(1.0281e-05),
 'map_large': tensor(0.0096),
 'map_medium': tensor(0.0022),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.0003),
 'mar_1': tensor(0.0024),
 'mar_10': tensor(0.0134),
 'mar_100': tensor(0.0532),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.1407),
 'mar_medium': tensor(0.1007),
 'mar_small': tensor(0.0293)}


Validation: 0it [00:00, ?it/s]

{'map': tensor(0.0056),
 'map_50': tensor(0.0156),
 'map_75': tensor(0.0011),
 'map_large': tensor(0.0181),
 'map_medium': tensor(0.0095),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.0038),
 'mar_1': tensor(0.0045),
 'mar_10': tensor(0.0205),
 'mar_100': tensor(0.0657),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.1653),
 'mar_medium': tensor(0.1311),
 'mar_small': tensor(0.0340)}


Validation: 0it [00:00, ?it/s]

{'map': tensor(0.0080),
 'map_50': tensor(0.0222),
 'map_75': tensor(0.0017),
 'map_large': tensor(0.0386),
 'map_medium': tensor(0.0147),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.0041),
 'mar_1': tensor(0.0063),
 'mar_10': tensor(0.0283),
 'mar_100': tensor(0.0789),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.1854),
 'mar_medium': tensor(0.1605),
 'mar_small': tensor(0.0405)}


Validation: 0it [00:00, ?it/s]

{'map': tensor(0.0108),
 'map_50': tensor(0.0306),
 'map_75': tensor(0.0017),
 'map_large': tensor(0.0470),
 'map_medium': tensor(0.0227),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.0045),
 'mar_1': tensor(0.0074),
 'mar_10': tensor(0.0346),
 'mar_100': tensor(0.0892),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.2051),
 'mar_medium': tensor(0.1787),
 'mar_small': tensor(0.0474)}


Validation: 0it [00:00, ?it/s]

{'map': tensor(0.0146),
 'map_50': tensor(0.0405),
 'map_75': tensor(0.0029),
 'map_large': tensor(0.0513),
 'map_medium': tensor(0.0302),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.0061),
 'mar_1': tensor(0.0088),
 'mar_10': tensor(0.0409),
 'mar_100': tensor(0.0994),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.2174),
 'mar_medium': tensor(0.1961),
 'mar_small': tensor(0.0548)}


Validation: 0it [00:00, ?it/s]

{'map': tensor(0.0202),
 'map_50': tensor(0.0496),
 'map_75': tensor(0.0107),
 'map_large': tensor(0.0658),
 'map_medium': tensor(0.0375),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.0097),
 'mar_1': tensor(0.0101),
 'mar_10': tensor(0.0469),
 'mar_100': tensor(0.1088),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.2377),
 'mar_medium': tensor(0.2117),
 'mar_small': tensor(0.0613)}


Validation: 0it [00:00, ?it/s]

{'map': tensor(0.0240),
 'map_50': tensor(0.0605),
 'map_75': tensor(0.0115),
 'map_large': tensor(0.0787),
 'map_medium': tensor(0.0479),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.0112),
 'mar_1': tensor(0.0117),
 'mar_10': tensor(0.0524),
 'mar_100': tensor(0.1177),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.2567),
 'mar_medium': tensor(0.2261),
 'mar_small': tensor(0.0675)}


Validation: 0it [00:00, ?it/s]

{'map': tensor(0.0279),
 'map_50': tensor(0.0712),
 'map_75': tensor(0.0122),
 'map_large': tensor(0.0874),
 'map_medium': tensor(0.0562),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.0125),
 'mar_1': tensor(0.0130),
 'mar_10': tensor(0.0577),
 'mar_100': tensor(0.1257),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.2677),
 'mar_medium': tensor(0.2403),
 'mar_small': tensor(0.0731)}


Validation: 0it [00:00, ?it/s]

{'map': tensor(0.0314),
 'map_50': tensor(0.0808),
 'map_75': tensor(0.0131),
 'map_large': tensor(0.1020),
 'map_medium': tensor(0.0665),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.0138),
 'mar_1': tensor(0.0143),
 'mar_10': tensor(0.0626),
 'mar_100': tensor(0.1325),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.2847),
 'mar_medium': tensor(0.2499),
 'mar_small': tensor(0.0783)}


In [8]:
torch.save(model, "fasterrcnn_mobilenet_v3_large_fpn.pth")