In [None]:
import os
import sys
import shutil
import glob
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import cv2
import Augmentor
from sklearn.model_selection import train_test_split
# Path configuration. MODEL_DIR is the kernel's current working directory
# (normally the folder the notebook was launched from); the dataset is read
# from a "data" folder that sits next to it, i.e. <parent of MODEL_DIR>/data.
MODEL_DIR = os.getcwd()
ROOT_DIR = os.path.split(MODEL_DIR)[0]
DATA_DIR = os.path.join(ROOT_DIR, "data")
# Let modules living next to the notebook be imported.
sys.path.extend([MODEL_DIR])

In [None]:
# Reset the train/test output folders so each run starts from an empty split:
# <split>/img holds radiographs and <split>/mask holds the matching masks.
TRAIN_DIR = os.path.join(DATA_DIR, "train")
TEST_DIR = os.path.join(DATA_DIR, "test")

for split_dir in (TRAIN_DIR, TEST_DIR):
    # Only delete folders that exist: on a fresh checkout there is nothing to
    # remove, and an unconditional rmtree raises FileNotFoundError.
    if os.path.isdir(split_dir):
        shutil.rmtree(split_dir)
    # makedirs also (re)creates split_dir itself as the parent of each sub-folder.
    for sub_dir in ("img", "mask"):
        os.makedirs(os.path.join(split_dir, sub_dir), exist_ok=True)

In [None]:
# Resize the JSRT SCR heart masks to 512x512 and write them, under the same
# file names (still .gif), into a sibling "heart_resized" folder.
MASK_DIR = os.path.join(DATA_DIR, "JSRT", "scr", "masks")
MASK_HEART_DIR = os.path.join(MASK_DIR, "heart")
# Output folder: a sibling of MASK_HEART_DIR rather than a sub-folder of it, so
# listing MASK_HEART_DIR never picks up previously resized files as inputs.
# (This is the folder the resized paths always pointed at; before, it was never created.)
MASK_HEART_RESIZE_DIR = os.path.join(MASK_DIR, "heart_resized")
MASK_SIZE = (512, 512)  # cv2.resize dsize, given as (width, height)

os.makedirs(MASK_HEART_RESIZE_DIR, exist_ok=True)
# Only .gif files are masks (skips stray files); sorted because os.listdir
# order is arbitrary and downstream cells split/copy based on this order.
heart_mask_names = sorted(name for name in os.listdir(MASK_HEART_DIR) if name.endswith(".gif"))
heart_mask = [os.path.join(MASK_HEART_DIR, name) for name in heart_mask_names]
# Build output paths by directory, not by str.replace on the full path, so a
# "heart" elsewhere in the path cannot be rewritten by accident.
heart_mask_resize = [os.path.join(MASK_HEART_RESIZE_DIR, name) for name in heart_mask_names]

for src_img_path, dst_img_path in zip(heart_mask, heart_mask_resize):
    img = cv2.resize(plt.imread(src_img_path), dsize=MASK_SIZE, interpolation=cv2.INTER_AREA)
    # The output format follows the ".gif" extension of dst_img_path.
    plt.imsave(dst_img_path, img, cmap="gray")

In [None]:
# Collect the source radiographs, split the samples into train/test, and copy
# the training samples (radiograph + resized heart mask) into TRAIN_DIR.
NODULE_PNG_DIR = os.path.join(DATA_DIR, "JSRT", "nodules", "png")
NON_NODULE_PNG_DIR = os.path.join(DATA_DIR, "JSRT", "non_nodules", "png")
# sorted(): os.listdir order is arbitrary; a stable order keeps runs repeatable.
nodules_pngs = [os.path.join(NODULE_PNG_DIR, png) for png in sorted(os.listdir(NODULE_PNG_DIR))]
non_nodules_pngs = [os.path.join(NON_NODULE_PNG_DIR, png) for png in sorted(os.listdir(NON_NODULE_PNG_DIR))]
entire_img = nodules_pngs + non_nodules_pngs

# Fixed seed so the train/test split is reproducible across runs.
SPLIT_SEED = 42

# A sample id is the resized mask's file name without its ".gif" extension.
samples = [os.path.basename(path).replace(".gif", "") for path in heart_mask_resize]
train_samples, test_samples = train_test_split(samples, random_state=SPLIT_SEED)

for train_sample in train_samples:
    # Substring match on the full path (as before); if several paths match,
    # the last one is used, as it was when each copy overwrote the previous.
    img_matches = [path for path in entire_img if train_sample in path]
    mask_matches = [path for path in heart_mask_resize if train_sample in path]
    # Fail loudly: silently skipping a sample would leave img/ and mask/ with
    # different sets of files.
    if not img_matches or not mask_matches:
        raise FileNotFoundError(f"No radiograph or heart mask found for sample {train_sample!r}")
    shutil.copy(img_matches[-1], os.path.join(TRAIN_DIR, "img", f"{train_sample}.png"))
    # The resized masks are GIF files (written with a ".gif" extension by the
    # resize step); re-encode them so the ".png" file name matches the content.
    with Image.open(mask_matches[-1]) as mask_img:
        mask_img.save(os.path.join(TRAIN_DIR, "mask", f"{train_sample}.png"), format="PNG")

In [None]:
# Copy the held-out test samples (radiograph + resized heart mask) into TEST_DIR.
for test_sample in test_samples:
    # Substring match on the full path (as before); if several paths match,
    # the last one is used, as it was when each copy overwrote the previous.
    img_matches = [path for path in entire_img if test_sample in path]
    mask_matches = [path for path in heart_mask_resize if test_sample in path]
    # Fail loudly: silently skipping a sample would leave img/ and mask/ with
    # different sets of files.
    if not img_matches or not mask_matches:
        raise FileNotFoundError(f"No radiograph or heart mask found for sample {test_sample!r}")
    shutil.copy(img_matches[-1], os.path.join(TEST_DIR, "img", f"{test_sample}.png"))
    # The resized masks are GIF files (written with a ".gif" extension by the
    # resize step); re-encode them so the ".png" file name matches the content.
    with Image.open(mask_matches[-1]) as mask_img:
        mask_img.save(os.path.join(TEST_DIR, "mask", f"{test_sample}.png"), format="PNG")

In [None]:
# Augmentation

# option 1: 5x5 grid, magnitude 8; 1000 augmented pairs are written to
# TRAIN_DIR/img and TRAIN_DIR/mask as aug_<n>.png.
AUG_PREFIX = "aug_"

# Pair every training image with the mask of the SAME file name. glob() returns
# files in arbitrary order, so zipping two independent glob results can pair an
# image with another sample's mask. Previously generated aug_* files are
# skipped so re-running this cell does not augment already-augmented data.
ground_truth_images = sorted(
    path
    for path in glob.glob(os.path.join(TRAIN_DIR, "img", "*"))
    if not os.path.basename(path).startswith(AUG_PREFIX)
)
segmentation_mask_images = [
    os.path.join(TRAIN_DIR, "mask", os.path.basename(path)) for path in ground_truth_images
]
collated_images_and_masks = list(zip(ground_truth_images, segmentation_mask_images))
missing_masks = [mask_path for _img, mask_path in collated_images_and_masks if not os.path.exists(mask_path)]
if missing_masks:
    raise FileNotFoundError(f"{len(missing_masks)} training images have no mask, e.g. {missing_masks[:3]}")


def _load_array(path):
    """Read an image file fully into a numpy array, closing the file handle."""
    with Image.open(path) as pil_img:
        return np.asarray(pil_img)


# Kept as `images`: the option 2 cell reuses these source pairs.
# NOTE(review): the masks were produced from GIFs, so they are likely palette
# ("P") images and these arrays hold palette indices, not gray levels — confirm
# that matches what the training code expects.
images = [[_load_array(x), _load_array(y)] for x, y in collated_images_and_masks]

AugmentationPipeline = Augmentor.DataPipeline(images)
AugmentationPipeline.random_distortion(probability=1, grid_width=5, grid_height=5, magnitude=8)
# Augmentor documents `method` as "in" or "out"; the previous 'in ' carried a stray space.
AugmentationPipeline.gaussian_distortion(probability=1, grid_width=5, grid_height=5, magnitude=8, corner='bell', method='in')
augmented_images = AugmentationPipeline.sample(1000)

# Numbering starts at 1; afterwards `cnt` holds the next free index.
cnt = 1
for img, mask in augmented_images:
    Image.fromarray(img).save(os.path.join(TRAIN_DIR, "img", f"{AUG_PREFIX}{cnt}.png"))
    Image.fromarray(mask).save(os.path.join(TRAIN_DIR, "mask", f"{AUG_PREFIX}{cnt}.png"))
    cnt += 1

In [None]:
# option 2: stronger distortions (7x7 grid, magnitude 10) applied to the same
# source pairs (`images`, built in the option 1 cell); 500 more augmented pairs.
AugmentationPipeline = Augmentor.DataPipeline(images)
AugmentationPipeline.random_distortion(probability=1, grid_width=7, grid_height=7, magnitude=10)
# Augmentor documents `method` as "in" or "out"; the previous 'in ' carried a stray space.
AugmentationPipeline.gaussian_distortion(probability=1, grid_width=7, grid_height=7, magnitude=10, corner='bell', method='in')
augmented_images = AugmentationPipeline.sample(500)

# Continue numbering after the highest aug_<n>.png already on disk, instead of
# relying on a `cnt` variable left in the kernel by the previous cell.
existing_aug_ids = []
for aug_path in glob.glob(os.path.join(TRAIN_DIR, "img", "aug_*.png")):
    aug_stem = os.path.splitext(os.path.basename(aug_path))[0][len("aug_"):]
    if aug_stem.isdigit():
        existing_aug_ids.append(int(aug_stem))
cnt = max(existing_aug_ids, default=0) + 1

for img, mask in augmented_images:
    Image.fromarray(img).save(os.path.join(TRAIN_DIR, "img", f"aug_{cnt}.png"))
    Image.fromarray(mask).save(os.path.join(TRAIN_DIR, "mask", f"aug_{cnt}.png"))
    cnt += 1