In [None]:
# Use this notebook to test the data augmentation functions and to generate ECG photograph datasets
# (see the dataset generation cells at the bottom of the notebook).

In [3]:
# imports

from data_augmentation import *
from IPython.core.display_functions import display
import json
import math
import numpy as np
from PIL import Image
import os

In [None]:
# tips for finding sensible default parameters for the parameter distributions of get_random_ecg_photo:
#
# - set all standard deviations (the *_sd parameters) to 0 except for the parameter(s) you want to vary;
#   downscale the image for faster generation
# - plot the Gaussian/Beta PDF with the chosen parameters (e.g. on desmos.com) to see how likely a given
#   parameter value is
# - a shadow can be either bright (camera flash) or dark (normal shadow); we model this with a U-shaped
#   Beta distribution (both shape parameters < 1)

In [None]:
# fetch one specific background image, useful for keeping the background fixed while experimenting
# (see the commented-out background_image argument in the next cell)
bg_img_number = 3  # presumably selects which background image is returned - see get_background_image
bg_img = get_background_image(bg_img_number)

In [None]:
# experiment with different parameters here
# (see the tips above: set the *_sd parameters to 0 to isolate the effect of a single parameter)

experiment_params = dict(
    original_image_scaling_factor_mu=1.0,
    original_image_scaling_factor_sd=0.0,
    blur_factor_mu=0.1,
    blur_factor_sd=0.05,
    ecg_paper_scale_mu=1.02,
    ecg_paper_scale_sd=0.02,
    ecg_paper_y_skew_mu=1.0,
    ecg_paper_y_skew_sd=0.08,
    rotation_angle_mu=0.0,
    rotation_angle_sd=5.0,
    ecg_paper_relative_translation_x_mu=0.0,
    ecg_paper_relative_translation_x_sd=0.05,
    ecg_paper_relative_translation_y_mu=0.0,
    ecg_paper_relative_translation_y_sd=0.05,
    shadow_color_beta_1=0.3,
    shadow_color_beta_2=0.3,
    shadow_alpha_beta_1=8.0,
    shadow_alpha_beta_2=25.0,
    shadow_relative_start_point_mu=0.0,
    shadow_relative_start_point_sd=0.1,
    shadow_relative_end_point_mu=1.0,
    shadow_relative_end_point_sd=0.1,
    shadow_blur_factor_mu=50.0,
    shadow_blur_factor_sd=3.0,
    white_noise_p_beta_1=5.0,
    white_noise_p_beta_2=5.0,
    white_noise_sd_mu=20.0,
    white_noise_sd_sd=5.0,
)
# uncomment to render on the fixed background image from the cell above
# experiment_params['background_image'] = bg_img

img, mask_imgs, label, lead_pos, angle = get_random_ecg_photo(**experiment_params)
display(img)

In [8]:
# create dataset directories, set save_dir
#
# layout: datasets/ptb_v/          -> dataset-wide config JSON files
#         datasets/ptb_v/training/ -> generated samples (save_dir)

root_dir = os.path.join('datasets', 'ptb_v')
config_dir = root_dir  # config files are stored directly in the dataset root
save_dir = os.path.join(root_dir, 'training')

for required_dir in (root_dir, config_dir, save_dir):
    os.makedirs(required_dir, exist_ok=True)

In [9]:
# dataset generation

import datetime

num_samples_to_generate = 100

# width of the zero-padded sample index in the file names: wide enough for the largest index
# (num_samples_to_generate - 1) so that the files sort in generation order
sample_idx_width = len(str(max(num_samples_to_generate - 1, 0)))

# samples containing any of these SCP codes are labelled as myocardial infarction (MI)
mi_list = ['IMI', 'ASMI', 'ILMI', 'AMI', 'LMI', 'IPLMI', 'IPMI', 'PMI']

# exactly num_samples_to_generate samples, matching the 'num_samples' value written below
for sample_idx in range(num_samples_to_generate):
    photo, mask_imgs, data, lead_pos, angle = get_random_ecg_photo()
    scp_codes = data['scp_codes']

    # file name: <zero-padded index>_<patient id>_<recording unix timestamp>[_<code>=<value>...]
    scp_code_str = ''.join(f'_{k}={v}' for k, v in scp_codes.items())
    rec_date = datetime.datetime.strptime(data['recording_date'], '%Y-%m-%d %H:%M:%S')
    rec_timestamp = (rec_date - datetime.datetime(1970, 1, 1)).total_seconds()
    id_str = f"{int(data['patient_id'])}_{int(rec_timestamp)}"
    # only the index is formatted; id_str and scp_code_str are inserted verbatim, so special
    # characters such as '%' in them cannot break the file name formatting
    path_no_ext = f'{sample_idx:0{sample_idx_width}d}_{id_str}{scp_code_str}'
    path_photo = os.path.join(save_dir, path_no_ext + '.jpg')
    photo.convert('RGB').save(path_photo)

    # now, we can merge the individual masks
    # use channel 3 (alpha channel) to detect presence of pixels (assumes RGBA mask images)
    curve_mask_arr = (np.array(mask_imgs[0])[:, :, 3] > 0).astype(np.uint8)
    thick_hor_lines_mask_arr = (np.array(mask_imgs[1])[:, :, 3] > 0).astype(np.uint8)
    thick_vert_lines_mask_arr = (np.array(mask_imgs[2])[:, :, 3] > 0).astype(np.uint8)

    # merged mask values: 0 = background, 85 = thick horizontal line, 170 = thick vertical line,
    # 255 = curve; vertical lines override horizontal ones, the curve overrides everything
    mask_arr = ((thick_hor_lines_mask_arr * 85 * (1 - thick_vert_lines_mask_arr) + thick_vert_lines_mask_arr * 170) *
                (1 - curve_mask_arr)) + curve_mask_arr * 255

    path_mask_img = os.path.join(save_dir, path_no_ext + '_mask.png')
    Image.fromarray(mask_arr).save(path_mask_img)

    # set label: MI takes precedence; 'normal' additionally requires a NORM value of at least 80.0
    if set(scp_codes) & set(mi_list):
        is_normal, is_non_mi_abnormality, is_mi = False, False, True
    elif 'NORM' in scp_codes and scp_codes['NORM'] >= 80.0:
        is_normal, is_non_mi_abnormality, is_mi = True, False, False
    else:
        is_normal, is_non_mi_abnormality, is_mi = False, True, False

    # per-sample metadata; keys starting with '_' are metadata, the remaining keys come from lead_pos
    img_width, img_height = photo.size
    lead_pos['_img_path'] = path_photo
    lead_pos['_mask_img_path'] = path_mask_img
    lead_pos['_angle'] = angle
    lead_pos['_img_width'] = img_width
    lead_pos['_img_height'] = img_height
    lead_pos['_scp_codes'] = scp_codes
    lead_pos['_is_normal'] = is_normal
    lead_pos['_is_non_mi_abnormality'] = is_non_mi_abnormality
    lead_pos['_is_mi'] = is_mi

    with open(os.path.join(save_dir, path_no_ext + '_data.json'), 'w') as lead_pos_file:
        lead_pos_file.write(json.dumps(lead_pos))

# write dataset parameters

dataset_params = get_default_ecg_param_dict()
dataset_params['num_samples'] = num_samples_to_generate
with open(os.path.join(config_dir, 'dataset_hyperparams.json'), 'w') as f:
    f.write(json.dumps(dataset_params))

In [10]:
# generation of individual lead pictures

# Both lead bounding box distortions below are drawn from zero-mean Gaussian distributions.
#
# coordinate distortion: sigma is given relative to the image size, i.e. x coordinates use
# (factor * image width) and y coordinates use (factor * image height)
(lead_bounding_box_distortion_sigma_x_relative_to_width,
 lead_bounding_box_distortion_sigma_y_relative_to_height) = (0.01, 0.01)

# rotation distortion: sigma in degrees, added per lead on top of the photo's rotation angle
lead_bounding_box_rotation_distortion_sigma = 3

def get_new_point_pos(pt_x, pt_y, angle_deg, img_width, img_height):
    """Rotate the point (pt_x, pt_y) about the image center by angle_deg degrees.

    In image coordinates (y axis pointing down), a positive angle rotates counter-clockwise
    as displayed: for 90 degrees, a point right of the center ends up above it.
    The center and the rotated offsets are truncated to ints.

    Returns:
        (x, y) tuple of ints with the rotated point.
    """
    # credit to https://stackoverflow.com/a/51964802
    cx, cy = int(img_width / 2), int(img_height / 2)
    theta = angle_deg / 180.0 * math.pi
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    dx, dy = float(pt_x - cx), float(pt_y - cy)
    return cx + int(dx * cos_t + dy * sin_t), cy + int(dy * cos_t - dx * sin_t)

individual_lead_dir = os.path.join(save_dir, 'individual_leads')
os.makedirs(individual_lead_dir, exist_ok=True)

# Only the per-sample JSON files located directly in save_dir are processed; os.walk would also
# descend into subfolders such as individual_lead_dir, where joining names onto save_dir is wrong.
for filename in sorted(os.listdir(save_dir)):
    json_path = os.path.join(save_dir, filename)
    if not filename.lower().endswith('.json') or not os.path.isfile(json_path):
        continue
    with open(json_path, 'r+') as f:
        content = f.read()
        try:
            lead_pos = json.loads(content)
            _img_path = lead_pos['_img_path']
            _mask_img_path = lead_pos['_mask_img_path']
            angle = lead_pos['_angle']
            img_width = lead_pos['_img_width']
            img_height = lead_pos['_img_height']
        except (ValueError, KeyError, TypeError) as e:
            # invalid JSON (JSONDecodeError is a ValueError) or not a sample file: report and skip
            print(f'Error with {filename}: {e}')
            continue

        # file name stem and extension (the extensions include the leading dot)
        photo_stem, photo_ext = os.path.splitext(os.path.basename(_img_path))
        mask_img_ext = os.path.splitext(os.path.basename(_mask_img_path))[1]
        individual_lead_fn_prefix = os.path.join(individual_lead_dir, photo_stem)

        rot_kwargs = {'img_width': img_width, 'img_height': img_height}

        individual_lead_img_paths = {}
        individual_lead_mask_img_paths = {}

        # the context managers close both image files once all leads have been cropped
        with Image.open(_img_path) as img, Image.open(_mask_img_path) as mask_img:
            lead_bounding_box_distortion_sigma_x = lead_bounding_box_distortion_sigma_x_relative_to_width * img.size[0]
            lead_bounding_box_distortion_sigma_y = lead_bounding_box_distortion_sigma_y_relative_to_height * img.size[1]

            for lead_name, pos in lead_pos.items():
                # keys starting with '_' hold metadata; all other entries are lead bounding boxes
                if lead_name.startswith('_'):
                    continue

                # apply rotation distortion per lead
                angle_distorted = angle + np.random.normal(0, lead_bounding_box_rotation_distortion_sigma)

                # coordinate distortion: x values sit at even indices, y values at odd indices
                tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y = [
                    round(p + np.random.normal(0, (lead_bounding_box_distortion_sigma_x if p_idx % 2 == 0
                                                   else lead_bounding_box_distortion_sigma_y)))
                    for p_idx, p in enumerate(pos)]

                corners = [(tl_x, tl_y), (tr_x, tr_y), (bl_x, bl_y), (br_x, br_y)]
                rotated_corners = [get_new_point_pos(x, y, angle_deg=-angle_distorted, **rot_kwargs)
                                   for x, y in corners]
                xs = [x for x, _ in rotated_corners]
                ys = [y for _, y in rotated_corners]

                # PIL crop boxes are (left, upper, right, lower) with exclusive right/lower edges, so the
                # edges are clamped to [0, width] resp. [0, height] to keep the last column/row reachable
                leftmost_x = max(0, min(img_width, min(xs)))
                topmost_y = max(0, min(img_height, min(ys)))
                rightmost_x = max(0, min(img_width, max(xs)))
                bottommost_y = max(0, min(img_height, max(ys)))
                if rightmost_x <= leftmost_x or bottommost_y <= topmost_y:
                    # the box lies completely outside the photo; cropping would yield a zero-sized image
                    print(f'Skipping lead {lead_name} of {filename}: bounding box outside of image')
                    continue

                crop_box = (leftmost_x, topmost_y, rightmost_x, bottommost_y)
                lead_img = img.rotate(-angle_distorted).crop(crop_box)
                lead_mask_img = mask_img.rotate(-angle_distorted).crop(crop_box)

                lead_img_path = f'{individual_lead_fn_prefix}_LEAD={lead_name}{photo_ext}'
                lead_img.save(lead_img_path)

                lead_mask_img_path = f'{individual_lead_fn_prefix}_LEAD={lead_name}_mask{mask_img_ext}'
                lead_mask_img.save(lead_mask_img_path)

                individual_lead_img_paths[lead_name] = lead_img_path
                individual_lead_mask_img_paths[lead_name] = lead_mask_img_path

        lead_pos['_lead_img_paths'] = individual_lead_img_paths
        lead_pos['_lead_mask_img_paths'] = individual_lead_mask_img_paths

        # overwrite the sample JSON file in place, now including the individual lead image paths
        f.seek(0)
        f.write(json.dumps(lead_pos))
        f.truncate()

# save individual lead image generation parameters
#
# The values are taken from the variables defined at the top of this cell (the ones the lead
# extraction above actually used) instead of being re-assigned here, so the saved file cannot
# drift from the parameters used for generation.

individual_lead_generation_params = {
    'lead_bounding_box_distortion_sigma_x_relative_to_width': lead_bounding_box_distortion_sigma_x_relative_to_width,
    'lead_bounding_box_distortion_sigma_y_relative_to_height': lead_bounding_box_distortion_sigma_y_relative_to_height,
    'lead_bounding_box_rotation_distortion_sigma': lead_bounding_box_rotation_distortion_sigma,
}

with open(os.path.join(config_dir, 'dataset_individual_lead_hyperparams.json'), 'w') as f:
    f.write(json.dumps(individual_lead_generation_params))