In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
import imgaug as ia
from imgaug import augmenters as iaa
import numpy as np
import scipy
import math
from scipy import misc
from glob import glob
import json
import cv2
from wand.image import Image
from tempfile import mkstemp
%pylab inline
pylab.rcParams['figure.figsize'] = (20, 20)

In [None]:
# --- Configuration -----------------------------------------------------------
# Root folder holding one sub-folder per tagging dataset
indir = "/Users/Alex/Desktop/tagging_iterations"

# Dataset (sub-folder of `indir`) to augment
dataset = "21-07-2017"

# True: process every photo; False: only the single debug photo "*248 (4)"
PROCESS_ALL_IMAGES = True

# Number of augmented copies generated per source photo
augment_times = 25

# Downscale factor applied to every photo on load; label coordinates are
# multiplied by the same factor so they stay aligned with the pixels
SCALE = 0.2

# Output folder for this dataset, created on first run
outdir = "/Users/Alex/Desktop/augmented_images/%s" % dataset
if not os.path.exists(outdir):
    os.makedirs(outdir)

In [None]:
# Read in photos, their labels, then the image, then zip together.
# Each photo "<name>.jpg" is expected to have a sibling label file "<name>.jpg.json".
# NOTE(review): glob() is called without recursive=True, so "**" matches exactly one
# directory level (like "*") between the dataset folder and the photo.
glob_d = "%s/%s/**/%s.jpg" % (indir, dataset, ("*" if PROCESS_ALL_IMAGES else "*248 (4)"))

def _read_label_json(label_path):
    """Parse one JSON label file, closing the file handle deterministically."""
    with open(label_path) as label_file:
        return json.load(label_file)

labels = [_read_label_json("%s.json" % p) for p in glob(glob_d) if os.path.exists("%s.json" % p)]

# Declare temporary files which we will eventually clean up
temp_files = []
        
def photo_has_runners(label):
    """Return True when the label lists at least one tagged runner."""
    runners = label["TaggedRunners"]
    return 0 < len(runners)

# Reject labels that are not tagged
labels = [label for label in labels if photo_has_runners(label)]
image_identifiers = [label["Identifier"] for label in labels]

# Anchor every expected file name on a path separator. A bare suffix match would let
# identifier "2" also accept ".../12.jpg", silently pairing the wrong images with labels.
# (glob joins the final path component with os.sep.)
files_to_accept = tuple([os.sep + "%s.jpg" % img_id for img_id in image_identifiers])
image_files = [image_file for image_file in glob(glob_d) if image_file.endswith(files_to_accept)]

# Downstream code pairs images[i] with labels[i] / image_identifiers[i]; fail loudly
# instead of producing misaligned annotations.
assert len(image_files) == len(labels), (
    "Found %i image files for %i tagged labels; images and labels would be misaligned"
    % (len(image_files), len(labels)))

def load_image(filename):
    """Load `filename` auto-oriented and downscaled by SCALE, returned as an image array.

    Wand auto-orients the photo and resizes it to int(width * SCALE) x int(height * SCALE).
    The result is written to a temporary file whose path is registered in `temp_files`
    (removed later by clean_temp_files) and read back with scipy.misc.imread.
    NOTE(review): scipy.misc.imread no longer exists in recent SciPy releases.
    """
    with Image(filename=filename) as img:
        img.auto_orient()
        img.resize(int(img.width * SCALE), int(img.height * SCALE))
        fd, temp_file = mkstemp()
        # mkstemp hands back an open OS-level descriptor that is never written through
        # (wand saves by path). Close it right away: keeping one open per image until
        # the end can exhaust the per-process open-file limit on large datasets.
        os.close(fd)
        temp_files.append(temp_file)
        print("Generating %s%% sampled version of '%s' to '%s'..." % (SCALE * 100, filename, temp_file))
        img.save(filename=temp_file)
    return misc.imread(temp_file)

def clean_temp_files():
    """Delete every temporary file registered by load_image.

    Paths that no longer exist are skipped, and the registry is emptied afterwards,
    so calling this more than once (e.g. re-running the cell) is safe.
    """
    for temp_path in temp_files:
        print("Deleting tempfile %s..." % temp_path)
        if os.path.exists(temp_path):
            os.remove(temp_path)
    # Empty in place: temp_files is the module-level list that load_image appends to.
    del temp_files[:]

# Decode every accepted photo (auto-oriented and downscaled), in image_files order
images = list(map(load_image, image_files))

def extract_bib_keypoints_on_image_from_label(label):
    """Build one ia.KeypointsOnImage holding the bib points of every tagged runner.

    Each runner carries label coordinates as "x, y" strings in runner["Bib"]["PixelPoints"].
    They are multiplied by SCALE (and truncated to int) to match the downscaled pixels
    produced by load_image.

    :param label: parsed label dict with "Identifier" and "TaggedRunners".
    :returns: ia.KeypointsOnImage with all runners' points (runner order, then point
        order), shaped like the entry of `images` whose identifier matches the label.
    """
    def extract_bib_keypoint_from_coords_str(coords_str):
        # "200, 300" => x=200*SCALE, y=300*SCALE. Splitting on "," alone also accepts
        # "200,300"; int() ignores the surrounding whitespace.
        coords = [int(int(pt) * SCALE) for pt in coords_str.split(',')]
        return ia.Keypoint(x=coords[0], y=coords[1])

    def extract_bib_keypoints_from_runner(runner):
        # All scaled keypoints of one runner's bib, in label order
        coords = runner["Bib"]["PixelPoints"]
        return [extract_bib_keypoint_from_coords_str(c) for c in coords]

    # The image whose shape the keypoints belong to
    image = images[image_identifiers.index(label["Identifier"])]

    # Flatten with a comprehension rather than np.array(...).flatten(). Runners with
    # differing point counts would otherwise yield a ragged object array (or a
    # ValueError on recent numpy) instead of a flat list of Keypoints.
    keypoints = [keypoint
                 for runner in label["TaggedRunners"]
                 for keypoint in extract_bib_keypoints_from_runner(runner)]

    # Return a single KeypointsOnImage
    return ia.KeypointsOnImage(keypoints, shape=image.shape)
    
# One KeypointsOnImage of scaled bib coordinates per label, in `labels` order
keypoints = []
for lbl in labels:
    keypoints.append(extract_bib_keypoints_on_image_from_label(lbl))

In [None]:
def affine():
    """Geometric jitter: random translation, rotation and shear.

    Translation is up to +/-35% of the image size on each axis, rotation up to
    +/-45 degrees and shear up to +/-5 degrees. Pixels uncovered by the transform are
    filled in "edge" mode (border pixels repeated).
    """
    def symmetric(limit):
        # (-limit, +limit) sampling interval
        return (-limit, +limit)

    max_shift = 0.35
    return iaa.Affine(
        translate_percent={"x": symmetric(max_shift), "y": symmetric(max_shift)},
        rotate=symmetric(45),
        shear=symmetric(5),
        mode="edge",
    )

def add_neg():
    """Darken: add a random value drawn from [-45, 0] to all channels."""
    offset_range = (-45, 0)
    return iaa.Add(offset_range)

def add_pos():
    """Brighten: add a random value drawn from [0, 45] to all channels."""
    offset_range = (0, 45)
    return iaa.Add(offset_range)

def mul_neg():
    """Darken: multiply all channels by a random factor in [0.5, 1] (below 1, not negative)."""
    factor_range = (0.5, 1)
    return iaa.Multiply(factor_range)

def mul_pos():
    """Brighten: multiply all channels by a random factor in [1, 1.5]."""
    factor_range = (1, 1.5)
    return iaa.Multiply(factor_range)

def blur():
    """Blur with one randomly chosen filter.

    The candidates are Gaussian (sigma in [0, 3]), average (kernel 2-4) and median
    (kernel 3-5).
    """
    candidates = [
        iaa.GaussianBlur((0, 3.0)),
        iaa.AverageBlur(k=(2, 4)),
        iaa.MedianBlur(k=(3, 5)),
    ]
    return iaa.OneOf(candidates)

def sometimes(aug, pct = 0.5):
    """Wrap `aug` so it is applied to each image only with probability `pct`.

    Example: sometimes(iaa.GaussianBlur(0.3)) blurs roughly every second image.
    """
    probability = pct
    wrapped = iaa.Sometimes(probability, aug)
    return wrapped
    
def one_of(funcs):
    """Alias for iaa.OneOf: run exactly one augmenter, picked at random from `funcs`."""
    chooser = iaa.OneOf(funcs)
    return chooser

# Augmentation pipeline, applied in a random order:
#  - always an affine warp;
#  - optionally an additive brightness shift, either up or down (50%);
#  - optionally a multiplicative brightness change, either up or down (50%);
#  - optionally a blur (30%).
pipeline_steps = [
    affine(),
    sometimes(one_of([add_pos(), add_neg()])),
    sometimes(one_of([mul_pos(), mul_neg()])),
    sometimes(blur(), 0.3),
]
seq = iaa.Sequential(pipeline_steps, random_order=True)

In [None]:
def keypoints_per_person(kpts):
    """Split kpts.keypoints into consecutive groups of four, one bib polygon per runner.

    The final group is shorter when the keypoint count is not a multiple of four.
    """
    points = kpts.keypoints
    group_size = 4
    groups = []
    start = 0
    while start < len(points):
        groups.append(points[start:start + group_size])
        start += group_size
    return groups

def plot_img(src_image, src_keypoints, aug_image, aug_keypoints, aug_rects):
    """Draw the source image (top panel) and the augmented image (bottom panel).

    Bib polygons (keypoints grouped four per runner) are outlined in lime on both
    panels. The rectangles in `aug_rects` (dicts with min_x/min_y/max_x/max_y) are
    drawn dashed red on the augmented panel only.

    :returns: the matplotlib Figure holding both panels.
    """
    fig, ax = plt.subplots(2)
    top_ax, bottom_ax = ax[0], ax[1]

    def outline_bibs(kpts, target_ax):
        # One unfilled polygon per runner's group of keypoints
        for group in keypoints_per_person(kpts):
            vertices = [[pt.x, pt.y] for pt in group]
            target_ax.add_patch(patches.Polygon(vertices, linewidth=3, edgecolor='lime', fill=False))

    top_ax.imshow(src_image)
    bottom_ax.imshow(aug_image)
    outline_bibs(src_keypoints, top_ax)
    outline_bibs(aug_keypoints, bottom_ax)

    for box in aug_rects:
        corner = (box["min_x"], box["min_y"])
        box_width = box["max_x"] - box["min_x"]
        box_height = box["max_y"] - box["min_y"]
        bottom_ax.add_patch(patches.Rectangle(corner, width=box_width, height=box_height, fill=False, linestyle="dashed", linewidth=3, color="red"))

    return fig

def show_results(img_idx):
    """Plot source vs. augmented version of image `img_idx` from the latest augmentation round.

    Reads the module-level `images`, `keypoints`, `aug_images` and `aug_keypoints`.
    The last two are only assigned by the augmentation loop, so call this after a
    round has run. Only runners whose augmented bib points all stay inside the image
    are drawn.

    :returns: the Figure produced by plot_img.
    """
    # Close the current figure (e.g. the previous preview) so figures don't pile up
    plt.close()
    image = images[img_idx]
    runner_groups = valid_keypoints(aug_keypoints[img_idx], image)
    image_aug_rects = keypoints_to_rects(runner_groups)
    # Flatten with a comprehension: np.array(...).flatten() breaks when groups are
    # ragged (the last group can hold fewer than four points).
    flat_keypoints = [kp for group in runner_groups for kp in group]
    image_aug_keypoints = ia.KeypointsOnImage(flat_keypoints, shape=image.shape)
    return plot_img(image, keypoints[img_idx], aug_images[img_idx], image_aug_keypoints, image_aug_rects)

In [None]:
def valid_keypoints(kpts, image):
    """Keep only the runners whose bib keypoints all lie inside the image.

    Keypoints are grouped four per runner (keypoints_per_person). A runner is dropped
    as soon as one of its points has x outside [0, width] or y outside [0, height].
    Width and height come from image.shape[1] and image.shape[0].

    :returns: list of the surviving per-runner keypoint groups, in their original order.
    """
    height, width = image.shape[0], image.shape[1]

    def is_hidden(pt):
        return pt.x < 0 or pt.x > width or pt.y < 0 or pt.y > height

    return [group for group in keypoints_per_person(kpts)
            if not any(is_hidden(pt) for pt in group)]

def keypoints_to_rects(kpts):
    """Convert each group of keypoints into its axis-aligned bounding box.

    :param kpts: iterable of keypoint groups (each an iterable of objects with .x/.y).
    :returns: list of dicts with keys min_x, min_y, max_x, max_y, one per group.
    """
    def bounding_box(group):
        xs = [pt.x for pt in group]
        ys = [pt.y for pt in group]
        return {"min_x": min(xs), "min_y": min(ys), "max_x": max(xs), "max_y": max(ys)}

    return [bounding_box(group) for group in kpts]

def generate_csv_for_image_kpts(image, image_identifier, kpts):
    """Return CSV text with one "bib,min_x,min_y,max_x,max_y" line per fully visible runner.

    Coordinates are formatted with %i. `image` only supplies the bounds for
    valid_keypoints; `image_identifier` is accepted but not used.
    """
    rects = keypoints_to_rects(valid_keypoints(kpts, image))
    return "\n".join(
        "bib,%i,%i,%i,%i" % (box["min_x"], box["min_y"], box["max_x"], box["max_y"])
        for box in rects
    )

def save_image(image, image_identifier, kpts, augment_no = "", is_augmented = True):
    """Write `image` as JPEG plus a sibling CSV of its bib rectangles into `outdir`.

    The base name is "<identifier>_aug<augment_no>" for augmented copies and
    "<identifier>_org<augment_no>" for originals (augment_no defaults to "").
    Existing files with the same name are overwritten.
    """
    unique_id = "%s_%s%s" % (image_identifier, "aug" if is_augmented else "org", augment_no)
    print("Saving %s image '%s' as '%s'..." % ("augmented" if is_augmented else "original", image_identifier, unique_id))
    # plt.imsave is the same function %pylab injects as bare `imsave`. Naming it
    # explicitly removes the hidden dependency on that star import.
    plt.imsave("%s/%s.jpg" % (outdir, unique_id), image)
    with open("%s/%s.csv" % (outdir, unique_id), "w") as csv_file:
        csv_file.write(generate_csv_for_image_kpts(image, image_identifier, kpts))

# Process (copy) all original data: save each downscaled source image and its CSV.
# NOTE: dict(zip(...)) keeps only the last entry for a repeated identifier.
org_data = dict(zip(image_identifiers, zip(images, keypoints)))
for image_identifier, (img, kpts) in org_data.items():
    save_image(img, image_identifier, kpts, is_augmented = False)

# Process (augment) all augmented data
try:
    for i in range(augment_times):
        # Process augment_times images
        print("Augmentation Round %i/%i..." % (i + 1, augment_times))
        # to_deterministic() freezes this round's random draws so that the images and
        # the keypoints below receive exactly the same transforms.
        seq_det = seq.to_deterministic()
        aug_images = seq_det.augment_images(images)
        aug_keypoints = seq_det.augment_keypoints(keypoints)
        aug_data = dict(zip(image_identifiers, zip(aug_images, aug_keypoints)))
        for image_identifier, data in aug_data.items():
            img, kpts = data[0], data[1]
            save_image(img, image_identifier, kpts, augment_no = i)
        if not PROCESS_ALL_IMAGES:
            # Save a preview to the test directory. savefig does not create missing
            # folders, so make sure it exists first.
            test_dir = os.path.normpath("%s/../../augmented_images_test" % outdir)
            if not os.path.exists(test_dir):
                os.makedirs(test_dir)
            show_results(0).savefig("%s/%s_aug.png" % (test_dir, i))
finally:
    # Clean all temps when done -- even if a round above raised part-way through.
    clean_temp_files()