In [1]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
import numpy as np
import pandas as pd
import os
from PIL import Image
from tqdm import tqdm

In [2]:
def load_images(split_dir):
    """Collect entry paths grouped by immediate sub-directory of ``split_dir``.

    Every immediate sub-directory of ``split_dir`` becomes one key of the
    result (the sub-directory name). Its value is the list of paths
    ``split_dir/<sub-directory>/<entry>`` for every entry found inside it.
    Entries directly under ``split_dir`` that are not directories are skipped.

    ``os.listdir`` returns names in arbitrary, platform-dependent order, so
    both levels are sorted here. This keeps the key order and the path order
    identical from one run (or machine) to the next.

    Args:
        split_dir: Path of the directory that contains one sub-directory per
            group of images.

    Returns:
        dict mapping sub-directory name -> sorted list of entry paths.
    """
    images = {}

    for p in sorted(os.listdir(split_dir)):
        p_dir = os.path.join(split_dir, p)
        if not os.path.isdir(p_dir):
            # e.g. a CSV file stored next to the sub-directories
            continue

        images[p] = [
            os.path.join(p_dir, image_name)
            for image_name in sorted(os.listdir(p_dir))
        ]
    return images

def assign_labels(images_dict, labels_dict):
    """Flatten a grouped image mapping into parallel path and label lists.

    Args:
        images_dict: Mapping of group key -> list of image paths.
        labels_dict: Mapping of group key -> label. Must contain every key of
            ``images_dict``; a missing key raises ``KeyError``.

    Returns:
        Tuple ``(images, labels)`` of two lists of equal length, where
        ``labels[k]`` is the label of the group that ``images[k]`` came from.
        Groups appear in the iteration order of ``images_dict``.
    """
    flat_paths, flat_labels = [], []
    for group_key, group_paths in images_dict.items():
        # Looked up once per group, even when the group has no images.
        group_label = labels_dict[group_key]
        flat_paths += group_paths
        flat_labels += [group_label] * len(group_paths)
    return flat_paths, flat_labels

In [3]:
# Root folder of the lymphocytosis classification dataset (inputs are read from here).
data_dir = "../../data/dlmi-lymphocytosis-classification/"
# Folder that receives the augmented images written by this notebook.
output_dir = "../../data/dlmi-lymphocytosis-augmented-data/"

In [4]:
# Locate the training split and its ground-truth table.
# Paths are built with os.path.join so they stay valid whether or not data_dir
# ends with a separator, and the CSV path reuses train_dir instead of repeating
# the folder name.
train_dir = os.path.join(data_dir, "trainset")
trainset_true_df = pd.read_csv(os.path.join(train_dir, "trainset_true.csv"))

# Map every value of the "ID" column to the "LABEL" value on the same row.
# If an ID appears on several rows, the last row wins.
labels_dict = dict(zip(trainset_true_df["ID"], trainset_true_df["LABEL"]))

In [5]:
# Index the training folder, pair every image path with the label of its
# sub-directory, and convert both parallel lists to NumPy arrays so they can be
# filtered with boolean masks.
images_names_dict = load_images(train_dir)
images_names, images_labels = map(
    np.array, assign_labels(images_names_dict, labels_dict)
)

In [6]:
# Keep only the image paths whose label equals 0.
# NOTE(review): presumably this is the class being augmented because it is
# under-represented — confirm against the label distribution.
images_names_0 = images_names[np.equal(images_labels, 0)]

In [7]:
# Decode every label-0 image into a NumPy array.
# Image.open is lazy and keeps the underlying file open, so each image is
# opened in a `with` block: np.array(img) forces the pixel data to be read, and
# the file handle is released as soon as the block exits instead of waiting for
# garbage collection.
images = []
for img_path in tqdm(images_names_0, desc='Loading'):
    with Image.open(img_path) as img:
        images.append(np.array(img))

# Stack into a single array indexed by image along axis 0.
# NOTE(review): assumes every image has the same shape and mode — confirm.
images = np.array(images)

Loading: 100%|██████████| 2592/2592 [00:01<00:00, 1666.69it/s]


In [8]:
# Number of augmented variants generated for each source image.
AUGMENTATIONS_PER_IMAGE = 5
# Seed for NumPy's global RNG so that a re-run reproduces the same variants.
AUGMENTATION_SEED = 42

# Random geometric augmentations; pixels exposed by a transform are filled
# with the nearest valid pixel value.
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(output_dir, exist_ok=True)

# ImageDataGenerator draws its random transform parameters from NumPy's global
# RNG; seeding it once makes the whole augmentation pass repeatable.
np.random.seed(AUGMENTATION_SEED)

augmented_images = []

for i in tqdm(range(images.shape[0]), desc='Augmentation'):
    # flow() expects a batch axis in front: (1, *image_shape).
    img = images[i][np.newaxis, ...]

    for j in range(AUGMENTATIONS_PER_IMAGE):
        # Files saved by flow() are named "<prefix>_<index in batch>_<random int>".
        # With a one-image batch the index is always 0, so a shared prefix
        # leaves only the small random suffix to tell files apart, and later
        # writes silently overwrite earlier ones. Encoding the source index and
        # the draw number in the prefix makes every file name unique.
        flow = datagen.flow(
            img,
            batch_size=1,
            save_to_dir=output_dir,
            save_prefix=f'aug_{i}_{j}',
            save_format='jpg',
        )
        # Pulling one batch applies one random transform and writes one file.
        batch = next(flow)
        augmented_images.append(batch[0])


Augmentation: 100%|██████████| 2592/2592 [01:12<00:00, 35.56it/s]
