In [None]:
import os

import albumentations as A
import numpy as np
import pandas as pd
from osgeo import gdal
from tensorflow.keras.utils import Sequence

In [None]:
def scene_folder_list(catagory, pathrow_list, data_source_folder='/workspace/_libs/dl_library'):
    """
    Return the paths of the chip folders the user requests.

    A folder inside ``<data_source_folder>/<catagory>`` is selected when its
    name ends with one of the requested pathrow IDs. A one-line summary of how
    many folders were found is printed.

    :param catagory: the category of landcover. We have four categories: Mangrove, Water, Wetland, and Pond
    :param pathrow_list: a list of pathrow IDs, e.g. ['128051','011060']
    :param data_source_folder: the data source folder of chips. On the blade, it is the default path. On Sandy, it should be r'//sandy.local/projects/Moore_Automation/DL_Training_Library'
    :return: the list of paths of chips folders, grouped in the order of
        ``pathrow_list``; folders matching the same pathrow are sorted by name.
    :raises OSError: if the category folder cannot be listed (raised by ``os.listdir``).
    """
    source_folder = os.path.join(data_source_folder, catagory)
    # The directory listing does not depend on the pathrow, so read it once.
    # Sorting makes the result independent of the arbitrary order os.listdir uses.
    folder_names = sorted(os.listdir(source_folder))
    res_folder_list = [
        os.path.join(source_folder, folder)
        for pathrow in pathrow_list
        for folder in folder_names
        if folder.endswith(pathrow)
    ]
    # The count is the number of matching folders, which can exceed the number
    # of requested pathrows when several folders share one pathrow suffix.
    print(str(len(res_folder_list)),'scenes exist for',str(len(pathrow_list)),'required scenes')
    return res_folder_list

In [None]:
def read_bands(path, num_bands=6):
    """
    Read the first ``num_bands`` raster bands of a chip into a single array.

    :param path: the file path of a chip
    :param num_bands: how many bands to read, starting at band 1 (default 6)
    :return: float64 numpy array of shape (rows, cols, num_bands), where rows
        and cols are taken from the opened dataset
    :raises OSError: if GDAL cannot open ``path``
    :raises ValueError: if the chip holds fewer than ``num_bands`` bands
    """
    dataset = gdal.Open(path)
    # gdal.Open reports failure by returning None unless GDAL exceptions are
    # enabled; fail here with the offending path instead of an AttributeError.
    if dataset is None:
        raise OSError('GDAL could not open chip: {}'.format(path))
    try:
        if dataset.RasterCount < num_bands:
            raise ValueError(
                'chip {} has {} band(s), {} requested'.format(
                    path, dataset.RasterCount, num_bands
                )
            )
        # Size the output from the dataset rather than assuming 256 x 256.
        res_array = np.zeros((dataset.RasterYSize, dataset.RasterXSize, num_bands))
        for num in range(num_bands):
            # GDAL band indices are 1-based.
            res_array[:, :, num] = dataset.GetRasterBand(num + 1).ReadAsArray()
    finally:
        # Drop the reference so GDAL can close the file.
        dataset = None
    return res_array

In [None]:
def read_labels(path):
    """
    Read the label raster (band 1) of a chip.

    :param path: the file path of a chip
    :return: 2-D numpy array holding band 1 of the chip, in the shape
        returned by GDAL's ``ReadAsArray``
    :raises OSError: if GDAL cannot open ``path``
    """
    dataset = gdal.Open(path)
    # gdal.Open reports failure by returning None unless GDAL exceptions are
    # enabled; fail here with the offending path instead of an AttributeError.
    if dataset is None:
        raise OSError('GDAL could not open label chip: {}'.format(path))
    try:
        # Labels live in the first band; GDAL band indices are 1-based.
        labels = np.array(dataset.GetRasterBand(1).ReadAsArray())
    finally:
        # Drop the reference so GDAL can close the file.
        dataset = None
    # No squeeze: a chip with a 1-pixel dimension keeps both of its axes.
    return labels

In [None]:
# Data augmentation using the albumentations library.
# Each step in the pipeline is applied independently with probability 0.5.
transform = A.Compose(
    [
        A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ]
)

In [None]:
class DataGenerator(Sequence):
  """
  Keras ``Sequence`` that yields batches of chips and their label rasters.

  The chip list comes from a CSV file with a ``bands`` column (paths read with
  ``read_bands``) and a ``labels`` column (paths read with ``read_labels``).
  Row i of both columns is treated as one sample.

  NOTE(review): chips are assumed to be 256 x 256, and ``read_bands`` is called
  without a band count, so ``num_bands`` must match the number of bands it
  returns -- confirm before using a value other than the default.
  """

  def __init__(self, data_csv_path, batch_size,num_bands=6, transform = None):
    """
    :param data_csv_path: path of the CSV file with ``bands`` and ``labels`` columns
    :param batch_size: number of samples per batch
    :param num_bands: size of the channel axis of the image batch (default 6)
    :param transform: optional callable invoked as
        ``transform(image=..., mask=...)`` that returns a mapping with
        ``'image'`` and ``'mask'`` entries (e.g. an albumentations ``Compose``);
        ``None`` disables augmentation
    """
    _data = pd.read_csv(data_csv_path)
    self.path_list = _data['bands'].tolist()
    self.labels_list = _data['labels'].tolist()
    self.batch_size = batch_size
    self.transform = transform
    self.num_bands = num_bands

  def __len__(self):
    """Return the number of full batches; a trailing partial batch is not counted."""
    return len(self.path_list) // self.batch_size

  def __getitem__(self,index):
    """
    Load batch number ``index``.

    :param index: zero-based batch index
    :return: tuple ``(x, l)`` of float32 arrays with shapes
        (n, 256, 256, num_bands) and (n, 256, 256), where n is the number of
        samples in the requested slice
    """
    start = index * self.batch_size
    stop = start + self.batch_size
    batch_bands_paths = self.path_list[start:stop]
    batch_labels_paths = self.labels_list[start:stop]

    # Size the outputs from the actual slice so that a short slice never
    # returns uninitialised np.empty rows.
    n_samples = len(batch_bands_paths)
    x = np.empty((n_samples, 256, 256, self.num_bands), dtype=np.float32)
    l = np.empty((n_samples, 256, 256), dtype=np.float32)

    for idx, (b_path, l_path) in enumerate(zip(batch_bands_paths, batch_labels_paths)):
        image = read_bands(b_path)
        mask = read_labels(l_path)
        if self.transform is not None:
            # A single label raster goes under the 'mask' key; 'masks' is
            # reserved by albumentations for a list of masks.
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        x[idx] = image
        l[idx] = mask

    return x,l