In [1]:
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision import transforms, utils
import pathlib
import json

In [2]:
class PubLayNetDataset(Dataset):
    """Object-detection dataset for a PubLayNet-style folder.

    Expected layout under ``root``::

        root/Images/       image files (sorted by name to define item order)
        root/samples.json  annotations with top-level "images" and "annotations" lists

    Each item is ``(image_tensor, target)`` where ``target`` is a dict with
    "boxes" (float32, shape [N, 4] as x_min, y_min, x_max, y_max), "labels"
    (int64, [N]), "image_id", "area" and "iscrowd".

    Args:
        root: dataset root directory.
        new_size: target (width, height) used only when ``resize`` is True.
        resize: if True, resize images to ``new_size`` and rescale boxes to match.
        bbox_format: layout of each annotation's "bbox" field.
            "xywh" (COCO convention: x, y, width, height) or
            "xyxy" (two corners, in any order).

    Raises:
        ValueError: if ``bbox_format`` is not one of the supported formats.
    """

    _BBOX_FORMATS = ("xywh", "xyxy")

    def __init__(self, root, new_size=(224, 224), resize=False, bbox_format="xywh"):  # new_size format: (W,H)
        if bbox_format not in self._BBOX_FORMATS:
            raise ValueError(f"bbox_format must be one of {self._BBOX_FORMATS}, got {bbox_format!r}")
        self.root = root
        self.new_size = new_size
        self.resize = resize
        self.bbox_format = bbox_format
        # load all image files, sorting them so item indices are stable
        self.imgs = sorted(os.listdir(os.path.join(root, "Images")))
        # .json annotation file path
        self.json = os.path.join(root, "samples.json")
        with open(self.json) as f:
            self.templates = json.load(f)

        # images format: {pic id: {'file_name': '...', 'height': ..., 'width': ..., 'annotations': [{...}, ...]}}
        self.images = {}
        for image in self.templates['images']:
            self.images[image['id']] = {'file_name': image['file_name'], 'height': image['height'],
                                        'width': image['width'], 'annotations': []}
        for ann in self.templates['annotations']:
            self.images[ann['image_id']]['annotations'].append(ann)
        # picture ids
        self.keys = list(self.images.keys())
        # file name -> picture id, built once so __getitem__ avoids a linear scan.
        # setdefault keeps the first id for a duplicated file name.
        self._id_by_file_name = {}
        for key, info in self.images.items():
            self._id_by_file_name.setdefault(info['file_name'], key)

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` for the ``idx``-th file in ``Images/``.

        Raises:
            KeyError: if the image file has no entry in samples.json.
        """
        file_name = self.imgs[idx]
        try:
            image_id = self._id_by_file_name[file_name]
        except KeyError:
            raise KeyError(f"{file_name!r} is in Images/ but has no entry in samples.json") from None

        # open inside `with` so the file handle is released deterministically
        with Image.open(os.path.join(self.root, "Images", file_name)) as raw_img:
            img = raw_img.convert("RGB")

        info = self.images[image_id]
        # (width, height) of the most recently loaded image, as recorded in samples.json
        self.size = (info['width'], info['height'])
        target = self._build_target(info['annotations'], image_id)

        if self.resize:
            img, target = self._prepare_sample(img, target, self.size)

        img = transforms.ToTensor()(img)
        return img, target

    def __len__(self):
        return len(self.imgs)

    @staticmethod
    def _bbox_to_xyxy(bbox, bbox_format="xywh"):
        """Convert one annotation bbox to ``[x_min, y_min, x_max, y_max]``."""
        if bbox_format == "xywh":
            x, y, w, h = bbox
            return [x, y, x + w, y + h]
        # "xyxy": two corners in unknown order -> normalise to min/max
        return [min(bbox[0], bbox[2]), min(bbox[1], bbox[3]),
                max(bbox[0], bbox[2]), max(bbox[1], bbox[3])]

    def _build_target(self, objects, image_id):
        """Build the detection target dict from a list of annotation dicts."""
        boxes = [self._bbox_to_xyxy(obj['bbox'], self.bbox_format) for obj in objects]
        labels = [obj['category_id'] for obj in objects]
        # reshape keeps shape (0, 4) for images without annotations, so the
        # column slicing used for `area` below still works
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        return {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.as_tensor([image_id]),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            # all instances are treated as not crowd
            "iscrowd": torch.zeros((len(objects),), dtype=torch.int64),
        }

    def _prepare_sample(self, image, target, size=None):
        """Resize ``image`` to ``self.new_size`` and rescale boxes/area to match.

        Args:
            image: PIL image.
            target: detection target dict; its "boxes" tensor is rescaled in place.
            size: (width, height) of the annotation coordinate space;
                defaults to ``self.size``.

        Returns:
            ``(np.ndarray of the resized image, target)``.
        """
        if size is None:
            size = self.size
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        image = image.resize(self.new_size, Image.LANCZOS)
        boxes = target["boxes"]
        x_ratio = size[0] / self.new_size[0]
        y_ratio = size[1] / self.new_size[1]
        boxes[:, 0::2] /= x_ratio  # x_min, x_max
        boxes[:, 1::2] /= y_ratio  # y_min, y_max
        target["boxes"] = boxes
        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        return np.array(image), target

In [3]:
# Dataset root (must contain Images/ and samples.json). Set the PUBLAYNET_ROOT
# environment variable to point elsewhere instead of editing this machine-specific path.
root = os.environ.get("PUBLAYNET_ROOT", '/media/kirb/ADATA HD680/examples/')

In [4]:
def get_object_detection_model(num_classes, pretrained=True):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a new box-predictor head.

    Args:
        num_classes: number of output classes *including* background
            (e.g. 6 for 5 layout classes + background).
        pretrained: passed to torchvision's ``pretrained`` argument; True loads
            the COCO-pretrained detector weights.

    Returns:
        The torchvision model with ``roi_heads.box_predictor`` replaced by a
        ``FastRCNNPredictor`` producing ``num_classes`` classes.
    """
    # load an object detection model (pre-trained on COCO when pretrained=True).
    # NOTE: newer torchvision deprecates `pretrained=` in favour of `weights=`;
    # it is kept here for compatibility with older versions.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=pretrained)

    # reuse the existing head's input width so the new head accepts the same box features
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # replace the pre-trained head with one sized for our classes
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

In [None]:
from engine import train_one_epoch, evaluate
# NOTE(review): presumably the utils.py from torchvision's detection reference
# scripts. It shadows the `torchvision.utils` imported in the first cell; only
# `utils.collate_fn` is used below.
import utils

batch_size = 2
# number of images used for training; the remaining shuffled images form the validation set
num_train = 16
# seed for the train/val split so it is reproducible after a kernel restart
split_seed = 0

# train on the GPU or on the CPU, if a GPU is not available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 6 classes: 5 + background
num_classes = 6
# new_size only matters when resize=True, so the constructor default is used here
publaynet = PubLayNetDataset(root, resize=False)

# split the dataset into train and val sets
split_generator = torch.Generator().manual_seed(split_seed)
indices = torch.randperm(len(publaynet), generator=split_generator).tolist()
if len(indices) <= num_train:
    raise ValueError(f"need more than {num_train} images to form a non-empty validation set, "
                     f"found {len(indices)}")
dataset = torch.utils.data.Subset(publaynet, indices[:num_train])
dataset_val = torch.utils.data.Subset(publaynet, indices[num_train:])

# define training and validation data loaders.
# num_workers stays at its default (0): when training inside a Jupyter notebook,
# num_workers > 0 raised an error, so the original num_workers=4 is disabled.
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True,
    collate_fn=utils.collate_fn)

data_loader_val = torch.utils.data.DataLoader(
    dataset_val, batch_size=batch_size, shuffle=False,
    collate_fn=utils.collate_fn)

# get the model using our helper function
model = get_object_detection_model(num_classes)

# move model to the right device
model.to(device)

# construct an optimizer over the trainable parameters only
params = [p for p in model.parameters() if p.requires_grad]

# SGD
optimizer = torch.optim.SGD(params, lr=0.0003,
                            momentum=0.9, weight_decay=0.0005)

# cosine-annealing learning rate with warm restarts, stepped once per epoch
# (first restart period is T_0=1 epoch, doubled each restart by T_mult=2)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=2)

num_epochs = 3
# log training progress every `print_freq` iterations
print_freq = 10

for epoch in range(num_epochs):
    # train for one epoch; presumably train_one_epoch (engine.py) moves
    # images and targets to `device` itself
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=print_freq)

    # update the learning rate
    lr_scheduler.step()

    # evaluate on the validation dataset
    evaluate(model, data_loader_val, device=device)

    print('')
    print('==================================================')
    print('')

print("That's it!")