In [1]:
import numpy as np
import rasterio
import gc
from tqdm import tqdm

In [2]:
def load_raster(file_path):
    """Read every band of a raster and return it in channel-last layout.

    Parameters
    ----------
    file_path : str
        Path handed to ``rasterio.open``.

    Returns
    -------
    raster_data : numpy.ndarray
        Pixel data with shape ``(rows, cols, bands)``.
    crs : str or None
        WKT string of the dataset CRS, or ``None`` when the dataset carries
        no CRS.
    transform :
        The dataset's ``transform`` attribute, returned unchanged.
    """
    with rasterio.open(file_path) as dataset:
        band_first = dataset.read()
        # A dataset without georeferencing exposes crs=None; do not crash on it.
        crs = dataset.crs.to_wkt() if dataset.crs is not None else None
        transform = dataset.transform
    # rasterio returns a band-first array (bands, rows, cols). Moving the band
    # axis to the end must be done with a transpose: reshape() only
    # reinterprets the flat buffer and scrambles pixels for any multi-band or
    # non-square raster.
    raster_data = np.transpose(band_first, (1, 2, 0))
    return raster_data,crs, transform

def sliceRaster(image,label, slice_size=256):
    """Cut an image and its label into aligned tiles.

    ``image`` must be 3-D ``(rows, cols, bands)``; ``label`` is indexed with
    the same row/column windows. Tiles are produced in row-major order
    (all column offsets for the first row offset, then the next row offset).
    Windows that run past the array edge are simply truncated by numpy
    slicing, so border tiles may be smaller than ``slice_size``.

    Returns a list of ``(image_tile, label_tile)`` tuples.
    """
    n_rows, n_cols, _ = image.shape
    tiles = [
        (
            image[top:top + slice_size, left:left + slice_size, :],
            label[top:top + slice_size, left:left + slice_size],
        )
        for top in range(0, n_rows, slice_size)
        for left in range(0, n_cols, slice_size)
    ]
    gc.collect()
    return tiles


def classifySlices(slice):
    """Return 1 if the tile contains at least one value equal to 255 or 1, else 0.

    The input is converted with ``np.array`` first, so nested lists are
    accepted as well as arrays.
    """
    values = np.array(slice)
    has_target = (values == 255).any() or (values == 1).any()
    return 1 if has_target else 0

def ChooseZeros(img_0_path, label_0_path, amount=10):
    """Pick a reproducible random subset of paired image/label paths.

    Parameters
    ----------
    img_0_path, label_0_path : sequence
        Parallel sequences; element ``k`` of one belongs with element ``k``
        of the other.
    amount : int, default 10
        Number of pairs to draw without replacement. When fewer pairs are
        available, all of them are returned (in shuffled order) instead of
        raising.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The selected image paths and the matching label paths, drawn with
        the same indices so the pairing is preserved.
    """
    # A private generator with the fixed seed 12 yields exactly the sample the
    # global np.random.seed(12) + np.random.choice pair produced, but without
    # reseeding numpy's process-wide RNG as a hidden side effect.
    rng = np.random.RandomState(12)
    n_available = len(img_0_path)
    idx = np.arange(n_available)
    sample = rng.choice(idx, min(amount, n_available), replace=False)
    img_0_paths = np.array(img_0_path)[sample]
    label_0_paths = np.array(label_0_path)[sample]

    return img_0_paths, label_0_paths


def saveSlices(slices, base_name, output_dir_train, output_dir_test,crs_wkt, transform, train=True):
    """Write every (image, label) tile to disk, bucketed by its class.

    Each tile is classified with ``classifySlices`` on its label, then the
    image and the label are written with ``WriteRaster`` to

        <split dir><class>\\image\\<base_name>_slice_<i>.tif
        <split dir><class>\\label\\<base_name>_slice_<i>.tif

    where the split dir is ``output_dir_train`` when ``train`` is truthy and
    ``output_dir_test`` otherwise. The directory strings are concatenated
    as-is, so they are expected to end with a path separator.

    NOTE(review): the same ``transform`` is passed for every tile, so all
    tiles are written with identical georeferencing regardless of their
    offset in the source raster -- confirm whether that is intended.
    NOTE(review): target directories are not created here; presumably they
    must already exist -- verify.
    """
    split_dir = output_dir_train if train else output_dir_test
    for index, (tile, tile_label) in enumerate(slices):
        class_dir = split_dir + str(classifySlices(tile_label))
        file_name = f"{base_name}_slice_{index}.tif"
        WriteRaster(tile, crs_wkt, transform, class_dir + '\\image\\' + file_name)
        WriteRaster(tile_label, crs_wkt, transform, class_dir + '\\label\\' + file_name)
    
def WriteRaster(array,crs_wkt,transform,output_file_path):
    """Write a channel-last array to a GeoTIFF, one band per channel.

    Parameters
    ----------
    array : numpy.ndarray
        Either ``(rows, cols, bands)`` or a 2-D ``(rows, cols)`` array, which
        is written as a single band.
    crs_wkt :
        Value passed as ``crs`` to ``rasterio.open``.
    transform :
        Value passed as ``transform`` to ``rasterio.open``.
    output_file_path : str
        Destination file.

    Raises
    ------
    ValueError
        If ``array`` is neither 2-D nor 3-D.
    """
    if array.ndim == 2:
        # Single-band data (e.g. a mask) without an explicit band axis:
        # add a trailing axis so it follows the same code path.
        array = array[:, :, None]
    elif array.ndim != 3:
        raise ValueError(
            f"expected a 2-D or 3-D array, got {array.ndim} dimension(s)"
        )
    rows, cols, bands = array.shape
    with rasterio.open(
        output_file_path, 'w', driver='GTiff', height=rows, width=cols,
        count=bands, dtype=array.dtype, crs=crs_wkt, transform=transform
    ) as dst:
        for b in range(bands):
            # rasterio band indexes are 1-based.
            dst.write(array[:, :, b], b + 1)

def ReadTextFile(file_path):
    """Return the lines of a text file as a list of strings.

    Line terminators are stripped (``str.splitlines`` semantics); blank
    lines inside the file are kept as empty strings.
    """
    with open(file_path, 'r') as handle:
        contents = handle.read()
    return contents.splitlines()

In [3]:
def main(train=True, zero=False):
    """Tile every listed raster/label pair into 256x256 slices and save them.

    The image and label path lists are read from two text files (one path
    per line, paired by position). Each pair is loaded, cut into tiles with
    ``sliceRaster`` and written with ``saveSlices``.

    Parameters
    ----------
    train : bool, default True
        Forwarded to ``saveSlices``; selects the train or test output root.
    zero : bool, default False
        When true, only a random subset chosen by ``ChooseZeros`` is
        processed.

    Raises
    ------
    ValueError
        If the two path lists do not have the same length.
    """
    image_paths = ReadTextFile("N:\\Tal\\from_1024_to_256\\text files\\2020\\image_1_training.txt")
    label_paths = ReadTextFile("N:\\Tal\\from_1024_to_256\\text files\\2020\\label_1_training.txt")
    # Fail before any work is done: pairing by position with unequal lists
    # would either crash part-way through the run or silently skip labels.
    if len(image_paths) != len(label_paths):
        raise ValueError(
            f"image list has {len(image_paths)} entries but label list has "
            f"{len(label_paths)}"
        )
    if zero:
        image_paths, label_paths = ChooseZeros(image_paths, label_paths)
    output_dir_train = "N:\\Tal\\2020\\data256\\train\\"
    output_dir_test = "N:\\Tal\\2020\\data256\\test\\"
    slice_size = 256
    pairs = zip(image_paths, label_paths)
    for i, (img, label_path) in enumerate(
        tqdm(pairs, total=len(image_paths), desc='Processing Raster Files')
    ):
        img_data, _, _ = load_raster(img)
        # The georeferencing handed to saveSlices is the label raster's, as
        # it always has been (it used to overwrite the image's silently).
        label_data, crs, transform = load_raster(label_path)
        slices = sliceRaster(img_data, label_data,slice_size)
        # Accept both "/" and "\" as separators so a Windows-style path does
        # not leak its directory part into the output file name; the name is
        # still cut at its first dot.
        base_name = img.replace("\\", "/").split("/")[-1].split(".")[0]
        saveSlices(slices, base_name, output_dir_train, output_dir_test, crs, transform,train=train)
        del img_data,label_data, slices
        if i%100 == 0:
            gc.collect()

In [4]:
# Run the tiling pipeline with main()'s defaults (train=True, zero=False).
main()

Processing Raster Files: 100%|███████████████████████████████████████████████████| 3418/3418 [2:16:37<00:00,  2.40s/it]
