# GADGET 2 Image Augmentation
See the [GADGET 2 GitHub Wiki](https://github.com/Jaros24/GADGET2/wiki/Image-Augmentation) for information on how to use this notebook.

In [1]:
# Configuration: every value below is read by the augmentation functions and the main loop.
image_dir = '/mnt/analysis/e17023/Adam/GADGET2/Output/images/' # directory of source images (keep the trailing slash)
output_dir = '/mnt/analysis/e17023/Adam/GADGET2/Output/aug_images/' # directory to save augmented images (keep the trailing slash)
duplication_factor = 50 # number of augmented copies attempted per source image (copies whose track cannot be placed are skipped)

# energy bar parameters
evar = 0.1 # relative error of the energy bar (between 0 and 1): the filled proportion is multiplied by uniform(1-evar, 1+evar)

# trace parameters
trace_scale = 1.5 # scale trace randomly between 1/trace_scale and trace_scale (1 is no scaling)
trace_mirror = True # randomly mirror the trace left-to-right (50% chance per image)
placement_error = 3 # maximum random shift applied to the trace x axis, drawn uniformly from [-placement_error, placement_error]

# track parameters
track_scale = 1.4 # scale track randomly between 1/track_scale and track_scale (1 is no scaling)
track_blur = 0 # Gaussian blur radius drawn uniformly between 0 and track_blur (0 is no blur)
noise_track = 0.3 # std of the Gaussian noise added to non-zero pads, as a fraction of the std of the non-zero pad values (0 for no noise)
track_edge = 2 # number of pads near track to consider for edge noise
edge_noise = 0.5 # std of the Gaussian noise added to edge pads, as a fraction of the smallest firing pad value (0 for no noise)
location_shuffle = True # randomly place track on padplane? (otherwise place in center)
veto_radius = 66.709 # radius (in padplane image pixels, measured from the padplane center) that the track must stay inside; leave 66.709 for realistic, or between 0 and 72
rotate_track = True # randomly rotate the padplane by a multiple of 90 degrees
mirror_track = True # randomly flip the padplane top-to-bottom (50% chance per image)
max_iters = 100 # maximum number of iterations to try to place track on padplane before giving up

In [2]:
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import random
import math
import os, sys, io
from PIL import Image
from tqdm import tqdm
import time
from PIL import ImageFilter

In [3]:
# pre-processing: collect the input file names and seed Python's RNG from the clock
image_list = os.listdir(image_dir) # list of images
random.seed(time.time())

# Circle mask for padplane validation, built in one vectorized pass:
# 0 where a pixel lies within veto_radius of the center (72, 72), 1 everywhere else.
mask_rows, mask_cols = np.indices((145, 145))
center_distance = np.sqrt((mask_rows - 72)**2 + (mask_cols - 72)**2)
circle_mask = np.where(center_distance <= veto_radius, 0.0, 1.0)

In [4]:
def track_rescale(track, scale, blur):
    """Randomly resize a track image and apply a random Gaussian blur.

    Parameters
    ----------
    track : array accepted by ``Image.fromarray`` (the caller passes a 2-D uint8 array).
    scale : the resize factor is drawn uniformly between ``1/scale`` and ``scale``
        and applied to both width and height.
    blur : the Gaussian blur radius is drawn uniformly between 0 and ``blur``.

    Returns
    -------
    numpy array holding the resized, blurred image.
    """
    picture = Image.fromarray(track)
    old_width, old_height = picture.size

    # first random draw: one common factor for both axes
    factor = random.uniform(1/scale, scale)
    target_size = (int(old_width * factor), int(old_height * factor))
    picture = picture.resize(target_size)

    # second random draw: blur radius
    radius = random.uniform(0, blur)
    blurred = picture.filter(ImageFilter.GaussianBlur(radius))
    return np.array(blurred)

In [5]:
def track_noise(track, evar):
    """Add zero-mean Gaussian noise to the positive entries of ``track``, in place.

    Parameters
    ----------
    track : float numpy array; entries > 0 are treated as firing pads.
    evar : noise standard deviation, expressed as a fraction of the standard
        deviation of the positive entries (0 adds no noise).

    Returns
    -------
    The same array object that was passed in (modified in place).
    Entries that are <= 0 on input are never touched.
    """
    firing = track > 0
    # Nothing to perturb: return early so np.std is never asked for the
    # spread of an empty selection (which warns and produces NaN).
    if not firing.any():
        return track

    pad_std = np.std(track[firing])
    track[firing] += np.random.normal(0, pad_std*evar, int(np.count_nonzero(firing)))
    return track

In [6]:
def _place_track(padplane, track, loc, circle_mask):
    """Blank the padplane, stamp ``track`` at pad location ``loc`` and test whether it fits.

    ``loc`` is (row, col) in pad units; one pad is 4 pixels and the stamp starts
    1 pixel in. Returns True when, in the region kept by ``circle_mask``
    (mask value 1), the red channel sums to the same value as the blue channel.
    Raises ValueError when the track does not fit in the padplane array.
    """
    # blank padplane: red and green are reset to the blue channel
    padplane[:,:,0] = padplane[:,:,2]
    padplane[:,:,1] = padplane[:,:,2]

    row0, col0 = loc[0]*4 + 1, loc[1]*4 + 1
    row1, col1 = row0 + track.shape[0], col0 + track.shape[1]
    # in-place add in the padplane's own dtype (the blue channel itself is left as-is)
    padplane[row0:row1, col0:col1, 0] += track[:,:,0] # red
    padplane[row0:row1, col0:col1, 1] += track[:,:,1] # green

    masked = padplane * circle_mask[:,:,np.newaxis].astype(padplane.dtype)
    return np.sum(masked[:,:,0]) == np.sum(masked[:,:,2])


def aug_padplane(image, circle_mask=circle_mask, track_scale=track_scale, track_blur=track_blur, noise_track=noise_track, edge_range=track_edge, edge_firing=edge_noise, max_iters=max_iters):
    """Augment the padplane region of an event image, in place.

    The track is lifted out of ``image[3:148, 40:185]``, converted to one value
    per pad, randomly rescaled/blurred, perturbed with noise, re-thresholded,
    redrawn with grid lines and stamped back at a random (or centered) location,
    then optionally mirrored and rotated.

    Parameters
    ----------
    image : numpy image array (rows, cols, >=3 channels); modified in place.
    circle_mask : 145x145 array, 1 where the track must NOT appear, 0 elsewhere.
    track_scale, track_blur : forwarded to ``track_rescale``.
    noise_track : forwarded to ``track_noise`` as the relative noise level.
    edge_range : number of pads around a firing pad that may pick up edge noise.
    edge_firing : std of the edge noise as a fraction of the smallest firing pad value.
    max_iters : number of failed random placements tolerated before giving up
        (defaults to the notebook-level ``max_iters`` setting).

    Also reads the notebook-level ``location_shuffle``, ``mirror_track`` and
    ``rotate_track`` settings at call time.

    Returns
    -------
    ``image`` on success, or ``None`` when there is no track to work with or
    no valid placement was found. The padplane region of ``image`` may already
    have been overwritten when ``None`` is returned.
    """
    # extract padplane from image (a view: writes below go straight into `image`)
    padplane_bounds = ((3,40),(148,185))
    padplane = image[padplane_bounds[0][0]:padplane_bounds[1][0], padplane_bounds[0][1]:padplane_bounds[1][1], :]

    track = padplane.copy()
    track[track == 255] = 0 # remove white background

    track = track[1:, 1:]
    track = track[::4,::4] # downsample track to remove grid lines

    # extract relative energy of pads
    # NOTE(review): red is divided by 205 here but multiplied by 204 when recoloring below - confirm which is intended
    track = np.average([track[:,:,0].astype(np.float32) / 205, track[:,:,1].astype(np.float32) / 240], axis=0)

    firing = track > 0
    if not firing.any():
        return None # no firing pads: nothing to augment
    pad_threshold = track[firing].min()
    track = np.pad(track, ((10,10),(10,10)), 'constant', constant_values=0) # pad track to prevent edge effects

    # PIXEL BASED AUGMENTATIONS
    # rescale track and blur (blur is honored even when no rescaling is requested)
    if track_scale != 1 or track_blur > 0:
        track = track_rescale((track*255).astype(np.uint8), track_scale, track_blur).astype(np.float32) / 255

    # directly add noise to track
    track = track_noise(track, noise_track)

    # add random firing pixels near edge of track
    # mark every pixel within edge_range (inclusive on both sides) of a firing pixel
    edge_pixels = np.zeros(track.shape)
    n_rows, n_cols = track.shape
    for i in range(edge_range, n_rows - edge_range):
        for j in range(edge_range, n_cols - edge_range):
            if track[i,j] > 0:
                edge_pixels[i-edge_range:i+edge_range+1, j-edge_range:j+edge_range+1] = 1
    edge_pixels[track > 0] = 0 # remove pixels that are already firing
    track += np.random.normal(0, edge_firing * pad_threshold, track.shape) * edge_pixels

    # reapply threshold to track
    track[track < pad_threshold] = 0
    track[track >= 1] = 1

    # reconstruct track and crop to its bounding box
    track = track.repeat(4, axis=0).repeat(4, axis=1) # upsample track to original size
    nonzero_rows, nonzero_cols = np.where(track != 0)
    if nonzero_rows.size == 0:
        return None # every pad fell below threshold after augmentation
    track = track[nonzero_rows.min():nonzero_rows.max()+1, nonzero_cols.min():nonzero_cols.max()+1]

    # redraw gridlines
    for i in range(1, track.shape[0]//4):
        track[i*4, :] = 0
    for i in range(1, track.shape[1]//4):
        track[:, i*4] = 0
    track = track[1:, 1:]

    # recolor track
    track = np.stack((track*204, track*240, (track > 0) * (track != 1) * 255), axis=2).astype(np.uint8)

    if location_shuffle:
        # place track in random valid location on padplane
        in_bounds = False
        iters = 0
        while not in_bounds:
            try:
                loc = (random.randint(0, 36 - (track.shape[0] + 1) // 4), random.randint(0, 36 - (track.shape[1] + 1)//4)) # random location
                in_bounds = _place_track(padplane, track, loc, circle_mask)
            except ValueError:
                # empty randint range or track larger than the padplane
                in_bounds = False
            if not in_bounds: # track is outside radius (or did not fit), try again
                iters += 1
                if iters > max_iters:
                    return None
    else: # center track on padplane
        loc = (17 - track.shape[0]//8, 17 - track.shape[1]//8) # center shifted by half-track size
        try:
            fits = _place_track(padplane, track, loc, circle_mask)
        except ValueError:
            fits = False
        if not fits:
            return None # track doesn't fit in padplane

    # randomly mirror / rotate padplane
    if mirror_track:
        if random.randint(0,1):
            padplane = np.flip(padplane, 0)
    if rotate_track:
        padplane = np.rot90(padplane, random.randint(0,3), (0,1))

    # place padplane back into image
    image[padplane_bounds[0][0]:padplane_bounds[1][0], padplane_bounds[0][1]:padplane_bounds[1][1], :] = padplane
    return image

In [7]:
def aug_trace(image, trace_scale=trace_scale):
    """Augment the trace strip at the bottom of an event image, in place.

    The trace is recovered from rows 151 onward of the red channel as per-column
    sums, cropped with a randomly scaled and shifted window around its peak,
    optionally mirrored, re-rendered with matplotlib and written back into
    ``image[151:]``.

    Parameters
    ----------
    image : numpy image array (rows, cols, >=3 channels); modified in place.
        The strip ``image[151:]`` must have the size of the rendered figure
        (224x73 inches-times-dpi as configured below).
    trace_scale : the horizontal scale factor is ``trace_scale**u`` with ``u``
        uniform in [-1, 1].

    Also reads the notebook-level ``trace_mirror`` and ``placement_error``
    settings at call time.

    Returns
    -------
    ``image`` (the same array that was passed in).
    """
    trace = image[151:,:,0] # extract trace from image

    trace = np.sum(255-trace, axis=0).astype(np.int64) # per-column sum of (255 - red)

    # find most common non-zero value in trace; used as the baseline level
    trace_zero = np.bincount(trace[trace > 0]).argmax()

    # determine edges of trace: first and last column sitting at the baseline
    trace_edges = np.where(trace == trace_zero)
    trace_edges = (trace_edges[0][0], trace_edges[0][-1])
    trace_width = trace_edges[1] - trace_edges[0]

    trace[:trace_edges[0]] = trace_zero
    trace[trace_edges[1]:] = trace_zero

    trace = trace - trace_zero # set baseline to zero

    # peak position of the trace
    peakx = np.argmax(trace)
    x_trace = np.arange(trace.shape[0]).astype(np.float32)

    scale_factor = trace_scale**random.uniform(-1,1) # scale trace randomly between 1/trace_scale and trace_scale

    if trace_mirror:
        mirror = random.choice([-1,1]) # randomly mirror trace
    else:
        mirror = 1

    x_trace = (x_trace - peakx)*scale_factor + peakx # scale trace about peakx

    x_trace = x_trace + random.uniform(-placement_error, placement_error) # randomly shift trace x axis

    # crop trace: x_trace only selects the window; the plot below is drawn against sample index
    crop0 = np.where(x_trace > peakx - trace_width//2 - 1)[0][0]
    crop1 = np.where(x_trace < peakx + trace_width//2 + 1)[0][-1]
    trace = trace[crop0:crop1]

    if mirror == -1:
        trace = np.flip(trace, 0)

    # export trace as png to be loaded as matrix (same process as generating trace originally)
    my_dpi = 96
    fig_size = (224/my_dpi, 73/my_dpi)  # Fig size to be used in the main thread
    fig, ax = plt.subplots(figsize=fig_size)
    try:
        ax.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.clear()
        x = np.linspace(0, len(trace)-1, len(trace))
        ax.fill_between(x, trace, color='b', alpha=1)
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png', dpi=my_dpi)
            buf.seek(0)
            with Image.open(buf) as im:
                trace_img_png = np.array(im)
    finally:
        # release the figure; otherwise pyplot keeps one figure alive per augmented image
        plt.close(fig)

    image[151:,:,:] = trace_img_png[:,:,:3]
    return image

In [8]:
def aug_ebar(image, evar = evar):
    """Redraw the energy bar with a randomly perturbed fill level, in place.

    The bar occupies ``image[5:145, 8:17]``. Its current fill level is read
    from the first row (top to bottom) whose second column is not pure white,
    multiplied by a factor drawn uniformly from [1-evar, 1+evar], and the bar
    is then wiped to white and repainted by ``fill_energy_bar``.

    Returns the image array produced by ``fill_energy_bar``.
    """
    top, left = 5, 8
    bottom, right = 145, 17
    ebar = image[top:bottom, left:right, :]

    # 1d profile of the bar: channel mean of column 1, one value per row
    profile = np.array([np.mean(row[1, :]) for row in ebar])
    n_rows = profile.shape[0]

    # index of the first non-white row; the last row index when the bar is entirely white
    first_filled = next((idx for idx in range(n_rows) if profile[idx] != 255), n_rows - 1)

    proportion_filled = 1 - (first_filled - 1)/n_rows
    proportion_filled = proportion_filled * np.random.uniform(1-evar, 1+evar)

    image[top:bottom, left:right, :] = 255 # wipe the bar before repainting
    return fill_energy_bar(image, proportion_filled)

def blue_range(pad_plane, rows):
    """Paint ``rows`` segments of the lowest energy-bar band, bottom-up, in place.

    Segment k (k = 0, 1, ...) covers rows ``140-5k`` to ``144-5k`` and columns
    8 to 16; channel 0 is set to ``35*k`` and channel 1 to ``35*(k+1)``.
    Channel 2 is left untouched. Returns ``pad_plane``.
    """
    for step in range(rows):
        top = 140 - 5 * step
        pad_plane[top:top + 5, 8:17, 0] = 35 * step
        pad_plane[top:top + 5, 8:17, 1] = 35 * (step + 1)
    return pad_plane
def yellow_range(pad_plane, rows):
    """Paint ``rows`` segments of the second energy-bar band, bottom-up, in place.

    Segment k covers rows ``105-5k`` to ``109-5k`` and columns 8 to 16; only
    channel 2 is written, with the value ``220 - 15*k``. Returns ``pad_plane``.
    """
    for step in range(rows):
        top = 105 - 5 * step
        pad_plane[top:top + 5, 8:17, 2] = 220 - 15 * step
    return pad_plane
def orange_range(pad_plane, rows):
    """Paint ``rows`` segments of the third energy-bar band, bottom-up, in place.

    Segment k covers rows ``70-5k`` to ``74-5k`` and columns 8 to 16; channel 2
    is set to ``210 - 15*k`` and channel 1 to 15 less than that. Channel 0 is
    left untouched. Returns ``pad_plane``.
    """
    for step in range(rows):
        top = 70 - 5 * step
        shade = 210 - 15 * step
        pad_plane[top:top + 5, 8:17, 1] = shade - 15
        pad_plane[top:top + 5, 8:17, 2] = shade
    return pad_plane
def red_range(pad_plane, rows):
    """Paint ``rows`` segments of the top energy-bar band, bottom-up, in place.

    Segment k covers rows ``35-5k`` to ``39-5k`` and columns 8 to 16; channel 0
    is set to ``250 - 15*k`` while channels 1 and 2 are both set to 50.
    Returns ``pad_plane``.
    """
    for step in range(rows):
        top = 35 - 5 * step
        segment = (slice(top, top + 5), slice(8, 17))
        pad_plane[segment + (0,)] = 250 - 15 * step
        pad_plane[segment + (1,)] = 50
        pad_plane[segment + (2,)] = 50
    return pad_plane
def fill_energy_bar(image, proportion_filled):
    """Paint the energy bar of ``image`` up to ``proportion_filled``, in place.

    The bar has 28 segments split into four bands of 7, painted bottom-up by
    ``blue_range``, ``yellow_range``, ``orange_range`` and ``red_range``.
    ``floor(proportion_filled * 28)`` segments are filled; each band receives at
    most 7, so values above 1 simply fill the whole bar and values that floor
    to 0 or less paint nothing. Returns ``image``.
    """
    total_rows = math.floor(proportion_filled * 28) # how many segments should be filled
    band_painters = (blue_range, yellow_range, orange_range, red_range)
    for band_index, paint_band in enumerate(band_painters):
        rows_below = 7 * band_index # segments belonging to the bands underneath this one
        if total_rows > rows_below:
            paint_band(image, rows=min(total_rows - rows_below, 7))
    return image

In [9]:
# Augment every source image `duplication_factor` times.
# Each source file is decoded once and its file handle is closed right away;
# every augmentation then starts from a fresh copy because the aug_* functions
# modify the array they are given.
with tqdm(total=len(image_list)*duplication_factor) as progress:
    for image_name in image_list:
        with Image.open(os.path.join(image_dir, image_name)) as source:
            source_pixels = np.array(source)[:,:,:3] # keep the first three channels

        for copy_index in range(duplication_factor):
            progress.update(1)

            image = aug_padplane(source_pixels.copy()) # augment padplane
            if image is None: # no valid track placement was found: skip this copy
                continue
            image = aug_trace(image) # augment trace
            image = aug_ebar(image) # augment energy bar

            # save image as <copy index><original file name>
            plt.imsave(os.path.join(output_dir, str(copy_index) + image_name), image)

100%|██████████| 50/50 [00:04<00:00, 11.37it/s]


In [10]:
# create gif of pngs in aug_images
#os.system('convert -delay 10 -loop 0 ' + output_dir + '*.png ' + output_dir + '../aug_images.gif')