In [None]:
import os
import cv2
import sys
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import PIL
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data_utils
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator

In [None]:
class BrailleDataset(torch.utils.data.Dataset):
  """Braille sheet images paired with per-symbol bounding boxes.

  Images are read from ``<images>/<sheet>.png`` with sheets numbered from 1.
  The labels CSV has one row per symbol: sheet, x1, y1, x2, y2, symbol.
  """

  def __init__(self, images, labels):
    """
    Args:
        images (string): Directory with all the images.
        labels (string): Path to a csv file with labels.
    """
    df = pd.read_csv(labels, header=None, names=['sheet', 'x1', 'y1', 'x2', 'y2', 'symbol'])
    # Symbols are stored wrapped like "['x']"; strip the wrapper characters.
    df['symbol'] = (df['symbol']
                    .str.replace('[\'', '', regex=False)
                    .str.replace('\']', '', regex=False))
    classes = df['symbol'].unique()
    # One id per distinct symbol, in order of first appearance. The mapping is sized
    # from the data: a fixed range would silently drop extra classes, leaving their
    # labels unmapped.
    # NOTE(review): torchvision detection models treat label 0 as background, and the
    # first symbol here gets id 0. Confirm whether ids should start at 1, with
    # num_classes = len(classes) + 1.
    self.labels_dict = {name: number for number, name in enumerate(classes)}
    self.labels = df
    self.images = images

  def get_classname(self, value):
    """Return the symbol mapped to class id ``value``, or 'null' if none is."""
    for classname, class_number in self.labels_dict.items():
        if value == class_number:
            return classname
    return 'null'

  def __len__(self):
    # NOTE(review): __getitem__ assumes sheets are numbered 1..len contiguously.
    return len(self.labels.sheet.unique())

  def __getitem__(self, idx):
    """Return ``(image, target)`` for the ``idx``-th sheet (0-based; files are 1-based).

    ``image`` is the grayscale sheet divided by 255 (numpy float array).
    ``target`` holds ``boxes`` (float32, shape (n_objects, 4)), ``labels``
    (int64, shape (n_objects,)) and ``image_id`` (shape (1,)).

    Raises:
        FileNotFoundError: if the sheet image cannot be read.
    """
    sheet = idx + 1
    path = os.path.join(self.images, str(sheet) + '.png')
    image = cv2.imread(path, 0)  # 0 == cv2.IMREAD_GRAYSCALE
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError('Could not read image: ' + path)
    image = image / 255
    image_id = torch.tensor([sheet])
    rows = self.labels[self.labels['sheet'] == sheet]

    # Boxes shape: (n_objects, 4) as x1, y1, x2, y2.
    boxes = torch.as_tensor(
        rows[['x1', 'y1', 'x2', 'y2']].to_numpy().reshape((-1, 4)), dtype=torch.float32)

    # Labels shape: (n_objects,). Map symbol names to their integer class ids.
    labels = torch.as_tensor(
        rows['symbol'].map(self.labels_dict).to_numpy(dtype=np.int64), dtype=torch.int64)

    target = {"boxes": boxes, "labels": labels, "image_id": image_id}

    return image, target

In [None]:
# Pretrained Faster R-CNN (ResNet-50 FPN) whose box predictor is replaced by a
# fresh head sized for the Braille classes. At most 500 detections are kept per
# image, and boxes scoring below 0.4 are discarded at inference.
# NOTE(review): torchvision counts the background class in num_classes (label 0);
# confirm 45 matches the dataset's id scheme.
num_classes = 45
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    pretrained=True,
    box_detections_per_img=500,
    box_score_thresh=0.4,
)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

In [None]:
def collate_fn(batch):
    """
    Regroup a batch of ``(image, target)`` pairs into ``(images, targets)`` tuples.

    Images carry different numbers of objects, so their targets cannot be stacked
    into one tensor; each field is kept as a tuple of per-sample values instead.
    """
    return tuple(tuple(column) for column in zip(*batch))

In [None]:
train_size = 3000
test_size = 1000
# Seed for the train/test permutation, so a Restart & Run All trains and evaluates
# on the same sheets every time.
split_seed = 42

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

dataset = BrailleDataset('../input/sheets-w-noise/sheets_w_noise', '../input/sheets-labels-w-noise/sheets_labels_w_noise.csv')

# Train and test must be disjoint. With fewer than train_size + test_size sheets,
# the slices would overlap and test sheets would leak into training.
if train_size + test_size > len(dataset):
    raise ValueError(
        'Dataset has {} sheets; need at least {} for disjoint train/test splits.'.format(
            len(dataset), train_size + test_size))

# Shuffle indices reproducibly, then take two contiguous, non-overlapping slices.
split_generator = torch.Generator().manual_seed(split_seed)
indices = torch.randperm(len(dataset), generator=split_generator).tolist()
dataset_train = torch.utils.data.Subset(dataset, indices[:train_size])
dataset_test = torch.utils.data.Subset(dataset, indices[train_size:train_size + test_size])

train_loader = torch.utils.data.DataLoader(
    dataset_train, batch_size=2, shuffle=True, collate_fn=collate_fn)

test_loader = torch.utils.data.DataLoader(
    dataset_test, batch_size=2, shuffle=False, collate_fn=collate_fn)

model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=0.0001)

In [None]:
# Show the symbol -> class-id mapping built by the dataset.
print(f"{dataset.labels_dict}")

In [None]:
import random

# One random colour (three ints in 0..255) per class id. A dedicated, seeded RNG
# keeps each class's colour stable across epochs and kernel restarts, so saved
# drawings can be compared. It also leaves the global `random` state untouched.
_color_rng = random.Random(0)
colors = [tuple(_color_rng.randint(0, 255) for _ in range(3)) for _ in range(num_classes)]

# Class id -> colour tuple, looked up when drawing predicted boxes.
colors_classes = dict(enumerate(colors))

In [None]:
def draw_boxes(boxes, labels, scores, image, epoch, image_id):
    """
    Draw predicted boxes with "<class> <score>" captions on an image and save it.

    Args:
        boxes: iterable of ``[x1, y1, x2, y2]`` pixel boxes predicted by the network.
        labels: predicted class id per box (used to pick the colour from ``colors_classes``).
        scores: confidence per box.
        image: 2-D grayscale image; assumed to hold floats in [0, 1] (the dataset
            divides by 255) — values outside that range are clipped.
        epoch: epoch number, used in the output file name.
        image_id: 1-element tensor holding the sheet id, used in the output file name.

    Writes ``<image_id>_Epoch:<epoch>.png`` to the current directory.

    Raises:
        OSError: if the image file could not be written.
    """
    # Convert to an 8-bit BGR canvas *before* drawing, so the 0-255 class colours
    # share the pixels' value range. Drawing them on a [0, 1] float image and then
    # scaling by 255 overflows uint8.
    gray_u8 = np.round(np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    canvas = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2BGR)

    for box, class_number, score in zip(boxes, labels, scores):
        color = colors_classes.get(class_number)
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        cv2.rectangle(canvas, top_left, bottom_right, color, 2)
        cv2.putText(canvas, str(class_number) + ' ' + str(round(score, 2)), (int(box[0]), int(box[1] - 5)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2,
                    lineType=cv2.LINE_AA)

    # cv2 writes the BGR canvas directly; cv2.imwrite reports failure via its return value.
    fname = str(image_id.numpy()[0]) + '_Epoch:' + str(epoch) + '.png'
    if not cv2.imwrite(fname, canvas):
        raise OSError('Failed to write ' + fname)


In [None]:
def get_iou(boxA, boxB):
    """
    Intersection over Union of two ``[x1, y1, x2, y2]`` boxes.

    Args:
        boxA (float array): ground truth box,
        boxB (float array): predicted box.

    Returns 0 when the boxes do not overlap.
    """
    left, top = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    right, bottom = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])

    overlap_w = max(right - left, 0)
    overlap_h = max(bottom - top, 0)
    intersection = abs(overlap_w * overlap_h)

    if intersection == 0:
        return 0

    def _area(box):
        return abs((box[2] - box[0]) * (box[3] - box[1]))

    union = float(_area(boxA) + _area(boxB) - intersection)
    return intersection / union

In [None]:
def get_metrics(gt_boxes, pred_boxes, target, labels, iou_threshold=0.5):
    """
    Match predictions to ground truth and compute per-image detection metrics.

    Each prediction, in the given order, is matched to the still-unmatched
    ground-truth box with the highest IoU. The match counts only if that IoU is
    >= ``iou_threshold``, and each ground-truth box is matched at most once.

    Args:
        gt_boxes (np.ndarray): ground-truth boxes, one ``[x1, y1, x2, y2]`` per row.
        pred_boxes (np.ndarray): predicted boxes, same layout.
        target (dict): dataset target; ``target["labels"]`` is a tensor of gt class ids.
        labels (np.ndarray): predicted class id per predicted box.
        iou_threshold (float): minimum IoU for a prediction to count as a detection.

    Returns:
        tuple ``(avg_iou, clf_accuracy, targets, correct_detections)``:
        ``avg_iou`` is the mean IoU over matched pairs; ``clf_accuracy`` is the
        number of matched pairs with the right class divided by the number of
        predictions; ``targets`` is the list of unmatched ``(gt_box, gt_label)``
        pairs; ``correct_detections`` is matched pairs divided by ground-truth
        boxes. Each ratio is 0.0 when its denominator is zero.
    """
    gts = gt_boxes.tolist()
    preds = pred_boxes.tolist()
    trgs = target["labels"].numpy().tolist()
    lbls = labels.tolist()

    correct_clf = 0
    ious = []
    # Unmatched ground truth; an entry is popped once it has been matched.
    targets = list(zip(gts, trgs))

    for pred_box, pred_label in zip(preds, lbls):
        best_iou, best_idx = 0.0, None
        for idx, (gt_box, _) in enumerate(targets):
            iou = get_iou(gt_box, pred_box)
            if iou > best_iou:
                best_iou, best_idx = iou, idx

        if best_idx is None or best_iou < iou_threshold:
            continue

        _, gt_label = targets.pop(best_idx)
        ious.append(best_iou)
        if pred_label == gt_label:
            correct_clf += 1

    avg_iou = sum(ious) / len(ious) if ious else 0.0
    clf_accuracy = correct_clf / len(preds) if preds else 0.0
    correct_detections = len(ious) / len(gts) if gts else 0.0

    return avg_iou, clf_accuracy, targets, correct_detections

In [None]:
def evaluate(model, epoch, n_draw_batches=3):
    """
    Run ``model`` over ``test_loader`` and return epoch-level detection metrics.

    Predictions for every image in the first ``n_draw_batches`` batches are
    rendered to disk with ``draw_boxes``. The caller is expected to have put
    ``model`` in eval mode beforehand.

    Args:
        model: detection model returning dicts with ``boxes``, ``labels``, ``scores``.
        epoch (int): current epoch, forwarded to ``draw_boxes`` for file naming.
        n_draw_batches (int): number of leading batches whose predictions are drawn.

    Returns:
        tuple ``(average_iou, average_accuracy, predicted_boxes, undetected)``:
        the first three are means over all test images of the per-image values
        from ``get_metrics``; ``undetected`` holds one list of unmatched
        ``(gt_box, gt_label)`` pairs per image.

    Raises:
        ValueError: if ``test_loader`` yields no images.
    """
    ious = []
    accuracies = []
    undetected = []
    detected = []

    # Inference only: without no_grad() every forward pass would record an
    # autograd graph that is never used and only costs (GPU) memory.
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(test_loader):

            draw = batch_idx < n_draw_batches

            # Each 2-D numpy image becomes a (1, H, W) float tensor. The size comes
            # from the image itself rather than a hard-coded 2000x1500.
            inputs = [torch.from_numpy(image).float().unsqueeze(0).to(device) for image in images]
            outputs = model(inputs)

            for image, output, target in zip(images, outputs, targets):

                gt_boxes = target["boxes"].numpy()
                pred_boxes = output["boxes"].detach().cpu().numpy()
                labels = output["labels"].detach().cpu().numpy()
                scores = output["scores"].detach().cpu().numpy()

                if draw:
                    draw_boxes(pred_boxes, labels, scores, image, epoch, target["image_id"])

                avg_iou, clf_accuracy, undetected_objects, det_obj_percent = get_metrics(
                    gt_boxes, pred_boxes, target, labels)

                ious.append(avg_iou)
                accuracies.append(clf_accuracy)
                undetected.append(undetected_objects)
                detected.append(det_obj_percent)

    if not ious:
        raise ValueError('test_loader yielded no images to evaluate')

    # Averaging metrics over every test image of the epoch.
    average_iou = sum(ious) / len(ious)
    average_accuracy = sum(accuracies) / len(accuracies)
    predicted_boxes = sum(detected) / len(detected)

    return average_iou, average_accuracy, predicted_boxes, undetected

In [None]:
# Training configuration.
num_epochs = 5
do_eval = True
best_boxes = 0.0  # best "predicted boxes percent" so far; decides which weights are kept

# Per-epoch histories filled in by the training loop (each a separate list).
all_losses, clf_losses, bbox_losses = [], [], []
avg_ious, avg_accs, pred_boxes_num, undetected_objs = [], [], [], []

In [None]:
# Training model.

best_model_wts = copy.deepcopy(model.state_dict())

for epoch in range(num_epochs):

    model.train()

    running_loss = 0.0
    running_loss_cls = 0.0
    running_loss_bbox = 0.0
    n_batches = 0

    start = time.time()

    for images, targets in train_loader:

        # 2-D numpy images -> (1, H, W) float tensors; targets moved to the device.
        images = [torch.from_numpy(image).float().unsqueeze(0).to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # Accumulate plain floats. Summing the loss tensors themselves would keep
        # every batch's autograd graph (and activations) alive for the whole epoch.
        loss_value = losses.item()
        running_loss += loss_value
        running_loss_cls += loss_dict["loss_classifier"].item()
        running_loss_bbox += loss_dict["loss_box_reg"].item()
        n_batches += 1

    t = time.time() - start

    # torchvision's detection losses are normalized within each batch, so the epoch
    # value is the mean over batches (not the sum divided by the number of images).
    denom = max(n_batches, 1)
    epoch_loss = running_loss / denom
    epoch_loss_cls = running_loss_cls / denom
    epoch_loss_bbox = running_loss_bbox / denom

    all_losses.append(epoch_loss)
    clf_losses.append(epoch_loss_cls)
    bbox_losses.append(epoch_loss_bbox)

    print('Epoch: {}, Loss_classifier: {:.4f}, Loss_box_reg: {:.4f}, Train time: {:.4f} min'.format(epoch, epoch_loss_cls, epoch_loss_bbox, t / 60))

    if do_eval:
        model.eval()
        average_iou, average_accuracy, predicted_boxes, undetected = evaluate(model, epoch)
        avg_ious.append(average_iou)
        avg_accs.append(average_accuracy)
        pred_boxes_num.append(predicted_boxes)
        undetected_objs.append(undetected)
        print('Epoch: {}, Average IoU: {:.4f}, Average accuracy: {:.4f}, Average predicted boxes percent: {:.4f}'.format(epoch, average_iou, average_accuracy, predicted_boxes))

        # Keep the weights of the epoch with the highest detection rate.
        if best_boxes < predicted_boxes:
            best_boxes = predicted_boxes
            best_model_wts = copy.deepcopy(model.state_dict())

In [None]:
# Persist the state_dict of the epoch with the best detection rate.
PATH = "state_dict_model.pt"
torch.save(obj=best_model_wts, f=PATH)

In [None]:
# After a number sign, the letters а..ж are read as the digits 1..9 and 0.
numbers = dict(zip('абцдефгхиж', (1, 2, 3, 4, 5, 6, 7, 8, 9, 0)))

In [None]:
def image_to_text(pred_boxes, labels, n, line_gap=50, word_gap=80):
    """
    Translate predicted Braille symbols on one sheet into Russian text.

    The network returns detections in arbitrary order, so boxes are grouped into
    text lines by ``y1`` and then ordered by ``x1`` within each line. A horizontal
    ``x1`` jump larger than ``word_gap`` becomes a space. The capital sign
    upper-cases the next symbol. The number sign turns the following letters
    (see ``numbers``) into digits until the next word break or line break.

    Args:
        pred_boxes (float array): predicted ``[x1, y1, x2, y2]`` box per letter,
        labels (int array): predicted class id per box,
        n: image serial number; the text is written to ``<n>.txt``,
        line_gap (float): vertical ``y1`` jump (pixels) that starts a new line,
        word_gap (float): horizontal ``x1`` jump (pixels) that inserts a space.

    Returns:
        str: the decoded text, which is also written to ``<n>.txt`` as UTF-8.
    """
    number_sign = 'цифровой символ'
    capital_sign = 'знак заглавной буквы'

    symbols = list(zip(pred_boxes.tolist(), labels.tolist()))

    # Group into lines: sort by y1 and start a new line whenever consecutive y1
    # values differ by more than roughly one symbol height.
    lines = []
    prev_y1 = None
    for box, label in sorted(symbols, key=lambda s: s[0][1]):
        if prev_y1 is None or abs(box[1] - prev_y1) > line_gap:
            lines.append([])
        lines[-1].append((box, label))
        prev_y1 = box[1]

    text = ""
    is_capital = False
    for line in lines:
        # A line break ends a word, so number mode does not carry over.
        is_number = False
        ordered = sorted(line, key=lambda s: s[0][0])
        prev_x1 = ordered[0][0][0]

        for box, label in ordered:
            curr_x1 = box[0]
            symbol = dataset.get_classname(label)

            # Handle the word break first, so that a number sign starting a new
            # word is not cancelled by the space in front of it.
            if abs(prev_x1 - curr_x1) > word_gap:
                is_number = False
                text += " "
            prev_x1 = curr_x1

            # Indicator signs only change state; they produce no output.
            if symbol == number_sign:
                is_number = True
                continue
            if symbol == capital_sign:
                is_capital = True
                continue

            if is_capital:
                symbol = symbol.upper()
            elif is_number:
                # Letters without a digit value are kept as-is, not turned into 'None'.
                symbol = str(numbers.get(symbol, symbol))
            is_capital = False

            text += symbol

        text += '\n'

    with open(str(n) + '.txt', "w", encoding="utf-8") as f:
        f.write(text)

    return text