In [2]:
from utils import load_plant_village
import tensorflow as tf
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import os
import random

In [3]:
images = load_plant_village()

In [4]:
images

{'Apple___Apple_scab': [<PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>,
  <PIL.Image.Image image mode=RGB size=256x256>

In [17]:
all_images_df = pd.DataFrame(data=((k, img) for k, v in images.items() for img in v))


X_train, X_test, y_train, y_test = train_test_split(all_images_df[1], all_images_df[0], test_size=0.2, stratify=all_images_df[0])
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, stratify=y_train)

In [18]:
def get_crop_size(img, min_prop=0.7):
    def rand_prop(x):
        return x + (1. - x) * np.random.random()

    height, width, channels = img.shape

    if height > width:
        height_prop = rand_prop(min_prop)
        width_prop = rand_prop(height_prop)
    else:
        width_prop = rand_prop(min_prop)
        height_prop = rand_prop(width_prop)

    height = np.floor(height_prop * height).astype(int)
    width = np.floor(width_prop * width).astype(int)
    return height, width, channels

def random_augmentation(img):
    img = tf.keras.preprocessing.image.random_rotation(img, 20, row_axis=0, col_axis=1, channel_axis=2, fill_mode='reflect')
    img = tf.image.random_contrast(img, 0.8, 1.2)
    img = tf.image.random_brightness(img, 0.08)
    img = tf.image.random_hue(img, 0.025)
    img = tf.image.random_saturation(img, 0.85, 1.15)
    img = tf.image.random_jpeg_quality(img, 75, 95)
    img = tf.image.random_flip_up_down(img)
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_crop(img, get_crop_size(img, 0.67))
    return img.numpy()


def get_augmented_image(img):
    return Image.fromarray(random_augmentation(np.array(img.convert('RGB')))).resize((224, 224))

In [20]:
for img_label in y_train.unique():
    counter = 0
    current_class_indexes = y_train[y_train == img_label].index
    current_class_images = X_train[current_class_indexes]
    dest = f'datasets/augmented/PlantVillage/train/{img_label}/'
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    for img in current_class_images:
        img = Image.fromarray(np.array(img.convert('RGB'))).resize((224, 224))
        img.save(dest + str(counter) + '.jpg')
        counter += 1

    while counter < 1000:
        augmented_image = get_augmented_image(X_train[random.choice(current_class_images.index)])
        augmented_image.save(dest + str(counter) + '.jpg')
        counter += 1

for img_label in y_val.unique():
    counter = 0
    current_class_indexes = y_val[y_val == img_label].index
    current_class_images = X_val[current_class_indexes]
    dest = f'datasets/augmented/PlantVillage/val/{img_label}/'
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    for img in current_class_images:
        img = Image.fromarray(np.array(img.convert('RGB'))).resize((224, 224))
        img.save(dest + str(counter) + '.jpg')
        counter += 1

for img_label in y_test.unique():
    counter = 0
    current_class_indexes = y_test[y_test == img_label].index
    current_class_images = X_test[current_class_indexes]
    dest = f'datasets/augmented/PlantVillage/test/{img_label}/'
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    for img in current_class_images:
        img = Image.fromarray(np.array(img.convert('RGB'))).resize((224, 224))
        img.save(dest + str(counter) + '.jpg')
        counter += 1
