In [332]:
import os
from functools import reduce

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from imax import transforms
from PIL import Image

from jmd_imagescraper.core import *
from jmd_imagescraper.imagecleaner import *

from chex import Array, PRNGKey, Scalar, Shape
from jax.random import PRNGKey as jkey

## Dataset Creation
### Download galaxies from *DuckDuckGo*

In [None]:
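# One-off scraping / manual-cleaning step, left disabled so that "Run All"
# does not hit the network again. Uncomment to re-download the raw images.
# NOTE(review): this call targets the "galaxies" folder under `root`, while the
# loading cell below reads "datasets/galaxies_raw" - presumably the folder was
# renamed after cleaning; confirm before re-running.
# root = "datasets"
# duckduckgo_search(root, "galaxies", "spiral galaxy", max_results=250)
# display_image_cleaner(root)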

### Load galaxies from drive

In [8]:
PARENT_DIR = "datasets/galaxies_raw"
IMAGE_SIZE = (64, 64)  # (width, height) every image is resized to

# os.listdir returns entries in arbitrary order; sorting keeps image indices
# (and therefore everything derived from them) reproducible between runs.
files = sorted(os.listdir(PARENT_DIR))


def load_galaxy(file):
    """Load one image from PARENT_DIR as a resized RGB array.

    Args:
        file: File name relative to PARENT_DIR.

    Returns:
        A jnp array holding the image resized to IMAGE_SIZE with 3 channels.
    """
    # The context manager closes the underlying file handle as soon as the
    # pixel data has been copied out, instead of leaving it to the GC.
    with Image.open(os.path.join(PARENT_DIR, file)) as raw:
        # Forcing RGB keeps every array the same shape, so stacking below
        # cannot fail on a stray grayscale / palette / RGBA file.
        return jnp.array(raw.convert("RGB").resize(IMAGE_SIZE))


images = jnp.array(list(map(load_galaxy, files)))
images.shape

(217, 64, 64, 3)

### Get rid of repetitions

In [9]:
no_images = images.shape[0]

# Copy the batch to host memory once. Comparing pairs with jnp.array_equal
# inside a nested Python loop launches one device computation per pair,
# which is O(n^2) dispatches for what is a simple grouping problem.
host_images = jax.device_get(images)

# Group image indices by their raw pixel bytes. For arrays of identical shape
# and dtype this is exact element-wise equality.
# NOTE(review): byte equality differs from value equality only for float data
# (NaN, signed zero); the loading cell builds `images` from PIL pixels, which
# are presumably integer - confirm if the dtype ever changes.
indices_by_pixels = dict()
for idx in range(no_images):
    indices_by_pixels.setdefault(host_images[idx].tobytes(), []).append(idx)

# repetitions_dict[i] lists every later index j > i whose image equals image i
# (only indices that have at least one later duplicate appear as keys).
# Sorting keeps the keys in ascending index order.
repetitions_dict = dict(sorted(
    (bucket[pos], bucket[pos + 1:])
    for bucket in indices_by_pixels.values()
    for pos in range(len(bucket) - 1)
))

In [297]:
# Every index that duplicates an earlier image (flattened from the per-image lists).
repetitions_set = {
    duplicate_idx
    for duplicates in repetitions_dict.values()
    for duplicate_idx in duplicates
}
unique_indeces = set(range(no_images)).difference(repetitions_set)

# Set iteration order is an implementation detail, so sort before indexing to
# keep the surviving images in their original order. The explicit integer
# dtype keeps the gather valid even when the index list is empty.
unique_images = images[jnp.array(sorted(unique_indeces), dtype=jnp.int32)]

### Augment data

In [261]:
def augmentat_img(img, key: PRNGKey):
    """Apply one random rotate / flip / scale / translate transform to an image.

    Each random quantity is drawn from its own PRNG key:
      * rotation angle: uniform in [-pi/8, pi/8] radians
      * two flip flags: each True with probability 0.5
      * two scale factors: each uniform in [1.0, 1.5]
      * two translation offsets: each uniform in [-8.0, 8.0]

    Args:
        img: Image passed straight to `transforms.apply_transform`.
        key: JAX PRNG key; consumed by this call (it is split internally).

    Returns:
        The result of `transforms.apply_transform` for the combined transform.
    """
    # Seven independent draws need seven keys. Reusing a key makes two draws
    # return the same value, which previously tied both scale factors together.
    (
        rot_key,
        flip_key_a,
        flip_key_b,
        scale_key_a,
        scale_key_b,
        shift_key_a,
        shift_key_b,
    ) = jax.random.split(key, 7)

    rot = transforms.rotate(
        rad=jax.random.uniform(rot_key, minval=-jnp.pi / 8, maxval=jnp.pi / 8)
    )
    flip = transforms.flip(
        jax.random.uniform(flip_key_a) < 0.5,
        jax.random.uniform(flip_key_b) < 0.5
    )
    scale = transforms.scale(
        jax.random.uniform(scale_key_a, minval=1.0, maxval=1.5),
        jax.random.uniform(scale_key_b, minval=1.0, maxval=1.5)
    )
    translate = transforms.translate(
        jax.random.uniform(shift_key_a, minval=-8.0, maxval=8.0),
        jax.random.uniform(shift_key_b, minval=-8.0, maxval=8.0)
    )

    # The individual transforms are combined by matrix product before a single
    # resampling pass.
    # NOTE(review): mask_value=-1 is passed through unchanged; check the imax
    # docs for how negative mask values fill pixels outside the source image.
    return transforms.apply_transform(img, flip @ rot @ translate @ scale, mask_value=-1)

In [300]:
def augment_dataset(key: PRNGKey, dataset: Array, rate: Scalar, dir_path: str, label: str, ext: str = "png"):
    """Write every image of `dataset` plus random augmentations of it to disk.

    For each image the original is saved first, followed by `int(rate)`
    augmented copies, and then - with probability equal to the fractional part
    of `rate` - one extra augmented copy. Files are named
    `{label}_{n}.{ext}` with `n` counting up from 1 across the whole dataset.

    Args:
        key: JAX PRNG key used to derive all randomness.
        dataset: Images, iterated along the first axis.
        rate: Expected number of augmented copies per image; must be >= 0.
        dir_path: Output directory (created if it does not exist).
        label: File-name prefix.
        ext: File extension / image format handed to `plt.imsave`.

    Raises:
        ValueError: If `rate` is negative.
    """
    if rate < 0:
        raise ValueError(f"rate must be non-negative, got {rate}")

    reps = int(rate // 1)
    prob = rate % 1

    os.makedirs(dir_path, exist_ok=True)

    key, uniform_key = jax.random.split(key)

    # One threshold per image, copied to the host once so the loop below does
    # not synchronise with the device on every comparison.
    probs = jax.device_get(jax.random.uniform(uniform_key, shape=(dataset.shape[0],)))

    counter = 1
    for img_idx, img in enumerate(dataset):
        # Derive a fresh set of keys for every image. Sharing one set across
        # the dataset would apply the identical transforms to each image.
        img_keys = jax.random.split(jax.random.fold_in(key, img_idx), reps + 1)

        to_save = [img]
        to_save.extend(augmentat_img(img, img_keys[i]) for i in range(reps))
        if probs[img_idx] < prob:
            # The last key is reserved for the probabilistic extra copy.
            to_save.append(augmentat_img(img, img_keys[-1]))

        for out_img in to_save:
            plt.imsave(
                os.path.join(dir_path, f"{label}_{counter}.{ext}"),
                out_img
            )
            counter += 1