In [4]:
import os
import json
import glob
import torch
import cv2 as cv
import numpy as np
import torchvision
from PIL import Image

from pytorch_detection import utils
import pytorch_detection.transforms as T
from pytorch_detection.engine import train_one_epoch

In [25]:
def get_transform(train):
    """Build the image/target transform pipeline used by the dataset.

    The pipeline always converts the PIL image to a tensor and then to
    ``torch.float``. When ``train`` is truthy a random horizontal flip
    (probability 0.5) is appended.

    Args:
        train: whether to include the training-time augmentation.

    Returns:
        A ``T.Compose`` wrapping the selected transforms.
    """
    steps = [
        T.PILToTensor(),
        T.ConvertImageDtype(torch.float),
    ]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)

class OCRGanttDataSet(torch.utils.data.Dataset):
    """Detection dataset read from ``<root>/ocr-od``.

    The directory is scanned for ``*.jpg`` images and ``*.json`` annotation
    files. Each annotation file must contain a ``"shapes"`` list whose entries
    have a ``"label"`` (one of ``self.classes``) and ``"points"``; the first
    two points are treated as opposite corners of a bounding box.

    Images and annotations are paired by position in the two sorted file
    lists, which lines up as long as every image has an annotation file with
    the same stem (assumed, not checked here).

    Args:
        root: directory that contains the ``ocr-od`` folder.
        transforms: optional callable invoked as ``transforms(img, target)``
            and expected to return the transformed ``(img, target)`` pair;
            pass ``None`` to get the raw PIL image.
    """

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        self.data_dir = os.path.join(self.root, 'ocr-od')
        # A label's position in this list is the integer class id put in the
        # target (and the index used to map predictions back to a name).
        # NOTE(review): torchvision's Faster R-CNN reserves label 0 for the
        # background, so the first entry is probably never learned as an
        # object class. Fixing that needs a coordinated change (shift labels
        # by one AND raise the model's num_classes to 11) - confirm.
        self.classes = [
                        'deployment',
                        'App Design',
                        'Project Planning',
                        'Future Enhancement',
                        'Back-end dev',
                        'Front-end Development',
                        'tesing and QA',
                        'Deadline',
                        'Maintainance',
                        'Back-end development'
                        ]

        # glob returns files in arbitrary (filesystem-dependent) order, and
        # __getitem__ pairs the two lists by index, so both must be sorted to
        # keep every image next to its own annotation file.
        self.imgs = sorted(glob.glob(os.path.join(self.data_dir, '*.jpg')))
        self.annotations = sorted(glob.glob(os.path.join(self.data_dir, '*.json')))

        print("Number of images: ", len(self.imgs))
        print("Number of annotations: ", len(self.annotations))

        # Report every distinct label found on disk so it can be compared
        # against self.classes.
        all_labels = []
        for annotation_file in self.annotations:
            with open(annotation_file) as f:
                shapes = json.load(f)["shapes"]
            all_labels.extend(shape["label"] for shape in shapes)
        print("All labels: ", list(set(all_labels)))

    def __getitem__(self, idx):
        """Return ``(img, target)`` for sample ``idx``.

        ``target`` is a dict with:
            ``boxes``    float32 tensor of shape (N, 4) as [x1, y1, x2, y2]
            ``labels``   int64 tensor of N indices into ``self.classes``
            ``image_id`` tensor holding ``idx``
            ``area``     tensor of N box areas

        Raises:
            ValueError: if an annotation label is not in ``self.classes``.
        """
        # Paths are used exactly as glob returned them; rewriting separators
        # would corrupt file names that legitimately contain a backslash.
        img = Image.open(self.imgs[idx]).convert("RGB")

        with open(self.annotations[idx]) as f:
            shapes = json.load(f)["shapes"]

        boxes = []
        labels = []
        for shape in shapes:
            x1, y1 = shape["points"][0]
            x2, y2 = shape["points"][1]
            # The two corners may come in any order; normalise them to
            # (left, top, right, bottom).
            boxes.append([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])
            labels.append(self.classes.index(shape["label"]))

        # reshape keeps the (N, 4) layout even when the image has no boxes,
        # so the column slicing below works for N == 0 as well.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        image_id = torch.tensor([idx])

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of images found in the data directory."""
        return len(self.imgs)

In [26]:
# Faster R-CNN with a ResNet-50 FPN backbone, with the box predictor replaced
# by a fresh head that has `num_classes` outputs.
# The next cell rebinds `model` to a MobileNetV3 variant, so this model is only
# used if that cell is skipped.
# NOTE(review): torchvision counts the background as one of `num_classes`;
# with 10 dataset classes this probably needs to be 11 - confirm together with
# the label ids produced by the dataset.
num_classes = 10

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# Input width of the existing classification layer, reused for the new head.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

In [27]:
# Number of outputs of the replacement box-predictor head.
# NOTE(review): torchvision's Faster R-CNN counts the background (label 0) as
# one of `num_classes`, while the dataset has 10 object classes numbered from
# 0. This probably needs to become 11 together with a +1 shift of the dataset
# labels - confirm before changing, the two must be changed together.
num_classes = 10

# Pretrained Faster R-CNN (MobileNetV3-Large 320 FPN) with the box predictor
# swapped for a fresh head sized for this dataset.
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

dataset = OCRGanttDataSet(root="data", transforms=get_transform(train=True))

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=utils.collate_fn,
)

# The CPU is used by default (this is what the recorded run did). Set
# USE_CUDA = True to train on the GPU when one is available; the switch makes
# that choice explicit instead of detecting CUDA and then discarding the result.
USE_CUDA = False
device = torch.device('cuda' if USE_CUDA and torch.cuda.is_available() else 'cpu')
# Compare on device.type so that an indexed device such as 'cuda:0' is also
# reported as GPU.
print("CONSUMING GPU" if device.type == 'cuda' else "CONSUMING CPU")

Number of images:  49
Number of annotations:  49
All labels:  ['deployment', 'App Design', 'Project Planning', 'Future Enhancement', 'Back-end dev', 'Front-end Development', 'tesing and QA', 'Deadline', 'Maintainance', 'Back-end development']
CONSUMING CPU


In [28]:
model.to(device)

# Hand only the parameters that still require gradients to the optimizer.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# The scheduler is stepped once per epoch below, so the learning rate is
# multiplied by 0.1 every 3 epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

num_epochs = 10

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    lr_scheduler.step()

Epoch: [0]  [ 0/25]  eta: 0:00:24  lr: 0.000213  loss: 3.8630 (3.8630)  loss_classifier: 2.3937 (2.3937)  loss_box_reg: 0.0770 (0.0770)  loss_objectness: 1.1723 (1.1723)  loss_rpn_box_reg: 0.2200 (0.2200)  time: 0.9982  data: 0.0400  max mem: 0
Epoch: [0]  [10/25]  eta: 0:00:23  lr: 0.002294  loss: 1.9089 (2.3025)  loss_classifier: 1.2971 (1.4511)  loss_box_reg: 0.2964 (0.2432)  loss_objectness: 0.1451 (0.4542)  loss_rpn_box_reg: 0.1263 (0.1540)  time: 1.5987  data: 0.0195  max mem: 0
Epoch: [0]  [20/25]  eta: 0:00:08  lr: 0.004376  loss: 1.2268 (1.6657)  loss_classifier: 0.5034 (0.9675)  loss_box_reg: 0.3018 (0.2694)  loss_objectness: 0.1288 (0.2885)  loss_rpn_box_reg: 0.1061 (0.1402)  time: 1.7751  data: 0.0173  max mem: 0
Epoch: [0]  [24/25]  eta: 0:00:01  lr: 0.005000  loss: 1.0209 (1.5610)  loss_classifier: 0.4534 (0.8947)  loss_box_reg: 0.3239 (0.2879)  loss_objectness: 0.0968 (0.2508)  loss_rpn_box_reg: 0.0933 (0.1276)  time: 1.8279  data: 0.0167  max mem: 0
Epoch: [0] Total tim

In [29]:
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing the checkpoint.
os.makedirs('weights', exist_ok=True)
torch.save(model.state_dict(), 'weights/ocr-gantt.pt')

In [30]:
def inference_image(image_file, output_file='output.jpg', score_threshold=0.5):
    """Run the detector on one image and save a copy with the detections drawn.

    Relies on the notebook-level ``model``, ``device`` and ``dataset`` objects.
    As a side effect the model is switched to eval mode.

    Args:
        image_file: path of the image to run detection on.
        output_file: path the annotated image is written to.
        score_threshold: only detections with a score strictly above this
            value are drawn.

    Raises:
        IOError: if OpenCV cannot read ``image_file`` for drawing.
    """
    model.eval()
    img = Image.open(image_file).convert("RGB")
    transform = get_transform(train=False)
    img, _ = transform(img, None)
    batch = img.unsqueeze(0).to(device)

    # No gradients are needed for inference; skipping them saves memory and
    # lets the outputs be converted to numpy directly.
    with torch.no_grad():
        outputs = model(batch)
    detections = {k: v.cpu() for k, v in outputs[0].items()}

    boxes = detections['boxes'].numpy()
    labels = detections['labels'].numpy()
    scores = detections['scores'].numpy()

    # cv.imread yields BGR, which is also what cv.imwrite expects, so the
    # image is kept in BGR; converting to RGB before writing would save a file
    # with the red and blue channels swapped.
    canvas = cv.imread(image_file)
    if canvas is None:
        raise IOError("OpenCV could not read image: {}".format(image_file))

    for box, label, score in zip(boxes, labels, scores):
        if score > score_threshold:
            x1, y1, x2, y2 = (int(v) for v in box)
            cv.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv.putText(canvas, dataset.classes[int(label)], (x1, y1), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    cv.imwrite(output_file, canvas)

In [31]:
# Run the trained detector on one sample image; inference_image saves the
# annotated result to disk.
inference_image('data/ocr-od/GC_1.jpg')