In [1]:
import jax.numpy as jnp
from jax import random
import jax
import numpy as np

In [2]:
from keras.preprocessing.image import ImageDataGenerator

# Runtime data augmentation
def get_augmented_data_generator(
    X_train,
    Y_train,
    seed=0,
    data_gen_args=dict(
        rotation_range=10.0,
        # width_shift_range=0.02,
        height_shift_range=0.02,
        shear_range=5,
        # zoom_range=0.3,
        horizontal_flip=True,
        vertical_flip=False,
        fill_mode="constant",
    ),
):
    # Train data, provide the same seed and keyword arguments to the fit and flow methods
    X_datagen = ImageDataGenerator(**data_gen_args)
    Y_datagen = ImageDataGenerator(**data_gen_args)
    X_datagen.fit(X_train, augment=True, seed=seed)
    Y_datagen.fit(Y_train, augment=True, seed=seed)
    X_train_augmented = X_datagen.flow(
        X_train, batch_size=1, shuffle=True, seed=seed
    )
    Y_train_augmented = Y_datagen.flow(
        Y_train, batch_size=1, shuffle=True, seed=seed
    )

    train_generator = zip(X_train_augmented, Y_train_augmented)

    return train_generator

2022-05-31 20:11:21.827929: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


In [3]:
class UnetDataGenerator:
    def __init__(self, images, masks):
        self.keras_generator = get_augmented_data_generator(images, masks)

    def get_batch(self):
        image, mask = next(self.keras_generator)
        image = jax.device_put(image)
        mask = jax.device_put(mask)
        return {"image": image, "mask": mask}

    def get_batch_numpy(self):
        return next(self.keras_generator)

In [4]:
from PIL import Image
import glob

masks = glob.glob("../data/isbi2015/train/label/*.png")
orgs = glob.glob("../data/isbi2015/train/image/*.png")
imgs_list = []
masks_list = []
for image, mask in zip(orgs, masks):
    imgs_list.append(np.array(Image.open(image).resize((512,512))))
    masks_list.append(np.array(Image.open(mask).resize((512,512))))
imgs_np = np.asarray(imgs_list).reshape((30,512,512,1))
masks_np = np.asarray(masks_list).reshape((30,512,512,1))
dataset = {"images": imgs_np, "masks": masks_np}

In [5]:
unet_datagen = UnetDataGenerator(imgs_np, masks_np)
unet_datagen.get_batch()["image"]



DeviceArray([[[[0.],
               [0.],
               [0.],
               ...,
               [0.],
               [0.],
               [0.]],

              [[0.],
               [0.],
               [0.],
               ...,
               [0.],
               [0.],
               [0.]],

              [[0.],
               [0.],
               [0.],
               ...,
               [0.],
               [0.],
               [0.]],

              ...,

              [[0.],
               [0.],
               [0.],
               ...,
               [0.],
               [0.],
               [0.]],

              [[0.],
               [0.],
               [0.],
               ...,
               [0.],
               [0.],
               [0.]],

              [[0.],
               [0.],
               [0.],
               ...,
               [0.],
               [0.],
               [0.]]]], dtype=float32)

In [6]:
import keras_unet_utils

img, mask = unet_datagen.get_batch_numpy()
keras_unet_utils.plot_imgs(img, mask, nm_img_to_plot=2, figsize=6)

ModuleNotFoundError: No module named 'keras_unet_utils'