## Data augmentation with bounding-box labels

In [15]:
import os
from matplotlib.patches import Rectangle
%matplotlib inline
import cv2
from matplotlib import pyplot as plt
import albumentations as A
from pathlib import Path


## Conversion of YOLO labels to COCO/VOC

In [16]:
def yolo_to_coco(bbox, img_width, img_height):
    """
    Convert a single YOLO box into COCO layout, expressed in pixels.

    yolo: [x_centre, y_centre, width, height], value:(0-1)
    coco: [x_min, y_min, width, height], value:(pixels)

    x_min and y_min are truncated with int(); the returned width and height
    are the unrounded products of the normalised size and the image size.
    """
    cx, cy, norm_w, norm_h = bbox

    # Scale the normalised extent up to pixel units.
    pixel_w = norm_w * img_width
    pixel_h = norm_h * img_height

    # Step back half the extent from the centre to reach the top-left corner.
    left = int(cx * img_width - pixel_w / 2)
    top = int(cy * img_height - pixel_h / 2)

    return left, top, pixel_w, pixel_h

def yolo_to_voc(bbox, img_width, img_height):
    """
    Convert a single YOLO box into Pascal VOC corner layout, in pixels.

    yolo: [x_centre, y_centre, width, height], value:(0-1)
    voc: [x_min, y_min, x_max, y_max], value:(pixels)

    x_min and y_min are truncated with int(); x_max and y_max are obtained by
    adding the unrounded pixel width/height to those truncated corners.
    """
    cx, cy, norm_w, norm_h = bbox

    # Pixel extent of the box along each axis.
    span_x, span_y = norm_w * img_width, norm_h * img_height

    # Top-left corner: centre minus half the extent, truncated to an integer.
    left = int(cx * img_width - span_x / 2)
    top = int(cy * img_height - span_y / 2)

    # Bottom-right corner is measured from the (already truncated) top-left.
    return left, top, left + span_x, top + span_y

## Preprocessing text files
### Separate class_id and bbox coordinates

In [17]:
def processing_text_files(txt_file):
    """
    Parse the first line of a YOLO label file.

    Change format from : 1 0.5 0.3 0.2 0.5 - class x_centre y_centre width height
    to : class_id=1; bbox_coordinates=[0.5,0.3,0.2,0.5]

    Only the first line of the file is read, so only the first annotated
    object is returned. Fields may be separated by any run of whitespace and
    the line does not need to end with a newline.

    txt_file : path to the label file
    returns  : (class_id, bbox) with class_id an int and bbox a list of floats
    raises   : ValueError if the first line has fewer than five fields or a
               field is not numeric
    """
    with open(txt_file, 'r') as f:
        first_line = f.readline()

    # split() with no argument drops the trailing newline (when present) and
    # collapses repeated/trailing spaces, so nothing has to be sliced off by
    # position - slicing would eat a digit when the file has no final newline.
    fields = first_line.split()
    if len(fields) < 5:
        raise ValueError(
            "Expected 'class x_centre y_centre width height' on the first "
            "line of {}, got: {!r}".format(txt_file, first_line)
        )

    class_id = int(fields[0])
    bbox = [float(value) for value in fields[1:]]

    return class_id, bbox

## Visualize and save the augmented image and text file to a separate folder

In [21]:
def visualize_save_image(path_save_image, image_id, transform, image, bboxes, bbox_classes, image_width, image_height, number_plots=5):
    """
    Run `transform` on the image `number_plots` times and plot every result
    with its bounding boxes outlined in red.

    path_save_image : output folder; only used by the optional save calls below
    image_id : file-name prefix; only used by the optional save calls below
    transform : callable invoked as transform(image=..., bboxes=..., bbox_classes=...)
                and returning a mapping with 'image' and 'bboxes' entries
    image : image passed to `transform`
    bboxes : list of YOLO boxes passed to `transform`
    bbox_classes : class labels passed to `transform`
    image_width, image_height : pixel size used to scale the YOLO boxes for
                drawing (assumes the transform keeps the image size - confirm
                if resizing/cropping transforms are added)
    number_plots : number of augmented samples to generate and plot
    """
    n_cols = 3
    # Size the grid from number_plots (a fixed 2x3 grid overflows past six
    # samples); keep at least one row so plt.subplots gets a valid shape.
    n_rows = max(1, (number_plots + n_cols - 1) // n_cols)
    # squeeze=False keeps `ax` two-dimensional even when there is one row.
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)

    for i in range(number_plots):
        ## Augmentation done here for each image
        transformed = transform(
            image=image,
            bboxes=bboxes,
            bbox_classes=bbox_classes,
        )

        axis = ax[i // n_cols, i % n_cols]
        axis.imshow(transformed["image"])

        ## Bbox conversion
        # Iterate instead of indexing [0]: the returned list may be empty
        # (in which case only the image is shown) or hold several boxes.
        for yolo_bbox in transformed['bboxes']:
            x_min, y_min, box_width, box_height = yolo_to_coco(yolo_bbox, image_width, image_height)
            axis.add_patch(
                Rectangle(
                    (x_min, y_min),
                    box_width,
                    box_height,
                    linewidth=5,
                    edgecolor="r",
                    facecolor="none",
                )
            )

        # Uncomment to write each augmented sample (numbered from 1) to
        # path_save_image:
        # save_image(path_save_image, i + 1, transformed, image_id)
        # save_txtfile(path_save_image, i + 1, transformed, image_id)

    plt.show() # Comment this to not show image

def save_image(path_save_image, count, transformed, image_id, extension='.jpg'):
    """
    Write the augmented image to <path_save_image>/<image_id><count><extension>.

    path_save_image : output folder
    count : sample number appended to the file name
    transformed : mapping with the augmented image under 'image'; the image is
                  expected in RGB order, as produced by augment_with_label
    image_id : file-name prefix
    extension : image file extension, including the dot
    raises : OSError if OpenCV reports that the file could not be written
    """
    saveImage = os.path.join(path_save_image, image_id + str(count) + extension)

    augmented = transformed['image']
    # cv2.imwrite interprets 3-channel data as BGR, while the pipeline in this
    # file converts images to RGB right after loading; convert back so the
    # saved file does not end up with red and blue swapped.
    if augmented.ndim == 3 and augmented.shape[2] == 3:
        augmented = cv2.cvtColor(augmented, cv2.COLOR_RGB2BGR)

    # imwrite signals failure through its return value rather than raising.
    if not cv2.imwrite(saveImage, augmented):
        raise OSError("Could not write augmented image to {}".format(saveImage))

def save_txtfile(path_save_image, count, transformed, image_id, extension='.txt'):
    """
    Write the augmented boxes to <path_save_image>/<image_id><count><extension>
    as YOLO label lines: "<class_id> <x_centre> <y_centre> <width> <height>".

    One line is written per bounding box; when `transformed` holds no boxes an
    empty label file is written.

    path_save_image : output folder
    count : sample number appended to the file name
    transformed : mapping with 'bboxes' (box coordinates) and 'bbox_classes'
                  (class names, in the same order as the boxes)
    image_id : file-name prefix
    extension : label file extension, including the dot
    raises : ValueError if a class name is missing from the module-level
             category_id_to_name mapping
    """
    # Invert the module-level id -> name mapping once instead of scanning it
    # for every box.
    name_to_id = {name: class_id for class_id, name in category_id_to_name.items()}

    ## pre-processing format to be saved
    label_lines = []
    for bboxx, bbox_class in zip(transformed['bboxes'], transformed['bbox_classes']):
        if bbox_class not in name_to_id:
            raise ValueError(
                "Class {!r} is not present in category_id_to_name".format(bbox_class)
            )
        # Format the numbers explicitly rather than relying on str(tuple),
        # which breaks when the box is a list or holds numpy scalars.
        coordinates = ' '.join(str(float(value)) for value in bboxx[:4])
        label_lines.append("{} {}\n".format(name_to_id[bbox_class], coordinates))

    saveLabel = os.path.join(path_save_image, image_id + str(count) + extension)
    with open(saveLabel, 'w') as txt:
        txt.writelines(label_lines)

## Augmentation

In [22]:
def augment_with_label(path, path_save_image, ext=('.jpg','.jpeg','.JPG'), object_class=('positive',)):
    """
    Augment every labelled image found in a folder.

    path : path to folder of images and text files to be augmented
    path_save_image : path to output augmented images and text files
    ext : accepted extension(s) for image
    object_class : change according to the class_name to be augmented

    Each image is matched to the label file with the same stem (IMG1.jpg ->
    IMG1.txt). Images without a label file, or that OpenCV cannot read, are
    reported and skipped. The class id stored in the label file is not used;
    `object_class` supplies the label handed to the augmentation pipeline.
    """
    # str.endswith only accepts a str or a tuple of str.
    image_ext = ext if isinstance(ext, str) else tuple(ext)

    # Data Augmentation
    # The pipeline does not depend on the individual image, so build it once.
    transform = A.Compose(
        [A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.4),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45,p=0.5),
        A.RandomBrightnessContrast(p=0.3),
        A.Blur(blur_limit=5),
        A.HueSaturationValue(p=0.3)],
        bbox_params=A.BboxParams(format='yolo', label_fields=['bbox_classes']))

    bbox_classes = list(object_class) #class of the object

    # os.listdir order is arbitrary; sort it so runs are reproducible.
    for img in sorted(os.listdir(path)):
        if not img.endswith(image_ext):
            continue

        # Pair by file stem instead of zipping two independent listings, which
        # can mismatch images and labels or pick up unrelated files.
        txt_path = os.path.join(path, Path(img).stem + '.txt')
        if not os.path.isfile(txt_path):
            print("Skipping {}: label file {} not found".format(img, txt_path))
            continue

        # Load image
        image = cv2.imread(os.path.join(path, img))
        if image is None:
            # cv2.imread returns None instead of raising when loading fails.
            print("Skipping {}: image could not be read".format(img))
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Extract width and height of image for bbox conversion
        image_height, image_width = image.shape[:2]

        # Load text files and processing
        _, bbox = processing_text_files(txt_path)
        bboxes = [bbox] #need to be in list

        # Show and save image
        image_id = Path(img).stem + '_'
        visualize_save_image(path_save_image, image_id, transform, image, bboxes, bbox_classes, image_width, image_height, number_plots=5)
        

## Run augmentation

In [None]:
# Class id -> class name lookup; save_txtfile reads it as a module-level name.
category_id_to_name = {
    0: 'positive',
    1: 'negative',
}

# Placeholder locations: replace with the real input and output folders.
path_save_image = 'path/to/save/augmented/images'
path = '/path/to/folder/with/images/to/augment'

augment_with_label(path=path, path_save_image=path_save_image)