# In [None]:
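# A data generator is used to load these images and masks one batch at a time (e.g. 16 per batch).
# Keras' built-in generators are not used because they only understand standard image files such as jpeg,
# while this data is stored as .npy arrays, so a custom generator is written instead.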

#from tifffile import imsave, imread
import os
import numpy as np


def load_img(img_dir, img_list):
    """Load the ``.npy`` files named in *img_list* from *img_dir*.

    Args:
        img_dir: Directory containing the files. A trailing path separator
            is optional.
        img_list: Iterable of file names relative to *img_dir*.

    Returns:
        A numpy array built by stacking the loaded arrays in the order they
        appear in *img_list*. Names whose extension is not ``.npy`` are
        skipped.
    """
    images = [
        # os.path.join copes with img_dir given with or without a trailing
        # separator, unlike plain string concatenation.
        np.load(os.path.join(img_dir, image_name))
        for image_name in img_list
        # splitext looks at the final extension only, so names without a dot
        # no longer raise and names such as "case.v2.npy" are still loaded.
        if os.path.splitext(image_name)[1] == '.npy'
    ]

    return np.array(images)

# load_img (above) reads the arrays from disk; imageLoader (below) yields them batch by batch.
# It needs the image directory, the image list, the mask directory, the mask list and the batch size (how many per batch).

def imageLoader(img_dir, img_list, mask_dir, mask_list, batch_size):
    """Yield ``(X, Y)`` batches of images and masks forever.

    The generator walks through *img_list* / *mask_list* in steps of
    *batch_size*, loading each slice with ``load_img``, and starts over from
    the beginning once the end of the list is reached. The last batch of a
    pass may be smaller than *batch_size*.

    Args:
        img_dir: Directory containing the image files.
        img_list: File names of the images.
        mask_dir: Directory containing the mask files.
        mask_list: File names of the masks; sliced with the same indices as
            *img_list*, so the two lists are expected to be aligned.
        batch_size: Number of list entries per batch; must be at least 1.

    Yields:
        A tuple ``(X, Y)`` of the two arrays returned by ``load_img`` for the
        current slice of images and masks.

    Raises:
        ValueError: If *batch_size* is smaller than 1 or *img_list* is empty.
            As this is a generator, the error surfaces on the first
            ``next()`` call.
    """
    if batch_size < 1:
        # A zero or negative step never advances through the list.
        raise ValueError("batch_size must be at least 1")

    L = len(img_list)  # number of samples in one pass
    if L == 0:
        # Without this guard the infinite loop below would spin forever
        # without ever yielding a batch.
        raise ValueError("img_list is empty; nothing to generate")

    # Keras expects the generator to be infinite, hence the endless loop.
    while True:
        for batch_start in range(0, L, batch_size):
            limit = min(batch_start + batch_size, L)

            X = load_img(img_dir, img_list[batch_start:limit])    # images
            Y = load_img(mask_dir, mask_list[batch_start:limit])  # masks

            yield (X, Y)

############################################

#Test the generator

from matplotlib import pyplot as plt
import random

# Training needs the image and mask directories and the file names in each.
train_img_dir = "BraTS2020_TrainingData/input_data_128/train/images/"
train_mask_dir = "BraTS2020_TrainingData/input_data_128/train/masks/"
# os.listdir returns names in arbitrary order. Sort both lists so that the
# i-th image and the i-th mask belong together (assumes the image and mask
# file names sort into the same case order -- TODO confirm against the data).
train_img_list = sorted(os.listdir(train_img_dir))
train_mask_list = sorted(os.listdir(train_mask_dir))

########################

# Defining data gen

# A small batch size is used; presumably because every sample is a
# 128*128*128 volume with 3 channels -- NOTE(review): confirm memory budget.
batch_size = 2

# imageLoader is the generator function that yields (images, masks) batches.
train_img_datagen = imageLoader(train_img_dir, train_img_list,
                                train_mask_dir, train_mask_list, batch_size)

###############################

# Testing data gen

# Each call to the generator hands back the next (X, Y) pair, i.e. a batch of
# images and the matching batch of masks.
# Verify the generator: the builtin next() is the idiomatic way to advance it
# (it calls the generator's __next__ method).

img, msk = next(train_img_datagen)
# Output is one batch: batch_size images and batch_size masks.

# Plot to check that nothing is wrong:
# pick a random sample from the batch, then a random slice of that sample.

img_num = random.randint(0, img.shape[0] - 1)
test_img = img[img_num]
test_mask = msk[img_num]
test_mask = np.argmax(test_mask, axis=3)  # converting categorical to integer.

# random.randint includes BOTH endpoints, so the upper bound must be the last
# valid index; using shape[2] itself could produce an out-of-range slice.
n_slice = random.randint(0, test_mask.shape[2] - 1)
plt.figure(figsize=(12, 8))

# Channels 0, 1 and 2 are shown under the titles below; the modality order is
# taken from those titles.
plt.subplot(221)
plt.imshow(test_img[:, :, n_slice, 0], cmap='gray')
plt.title('Image flair')
plt.subplot(222)
plt.imshow(test_img[:, :, n_slice, 1], cmap='gray')
plt.title('Image t1ce')
plt.subplot(223)
plt.imshow(test_img[:, :, n_slice, 2], cmap='gray')
plt.title('Image t2')
plt.subplot(224)
plt.imshow(test_mask[:, :, n_slice])
plt.title('Mask')
plt.show()