In [2]:
import os
import numpy as np
import rasterio
from pathlib import Path
from tqdm.auto import tqdm
from multiprocess import Pool,cpu_count

In [3]:
# Tile size in pixels, ordered [y (rows), x (columns)].
tile_size_px = [1024, 1024]
# Pad partial edge tiles out to the full tile size?
pad = True
# Value to fill padded pixels with.
pad_value = 0
# Overlap between neighbouring tiles in pixels (applied on both axes).
tile_overlap_px = 256
# Validate: the stride (tile size - overlap) must be positive on both axes,
# otherwise the tiling loop never advances and runs forever.
if tile_overlap_px < 0 or tile_overlap_px >= min(tile_size_px):
    raise ValueError('Your overlap must be non-negative and less than the tile size')

In [4]:
# Compression for the output tiles: use 'JPEG' for imagery and 'LZW' for raw data.
output_compression = 'LZW'  # alternative: 'JPEG'
# Folder containing the input GeoTIFF.
geotiff_folder = '/data/Road extraction/old/test areas/11'
# File name of the input GeoTIFF.
geotiff_file_name = 'roadsv3.2_resnet18_500px_4.6_LR0.0001_FT_4_finshed.tif'
# Check the input file exists before creating anything on disk.
geotiff_path = os.path.join(geotiff_folder, geotiff_file_name)
geotiff_exists = os.path.isfile(geotiff_path)
print('Does the file exist?', geotiff_exists)
if not geotiff_exists:
    # Stop here rather than creating an empty output folder and failing later.
    raise FileNotFoundError(geotiff_path)
# Tiles go into a folder named after the input file (extension stripped),
# created alongside it.
output_folder = os.path.splitext(geotiff_path)[0]
print('output folder =', output_folder)
Path(output_folder).mkdir(exist_ok=True)

Does the file exist? True
output folder = /data/Road extraction/old/test areas/11/roadsv3.2_resnet18_500px_4.6_LR0.0001_FT_4_finshed


In [5]:
# Read the input raster's metadata once; the tiling cell below uses its
# 'width' and 'height' to lay out the tile grid.
with rasterio.open(geotiff_path) as source_raster:
    input_meta = source_raster.meta
input_meta

{'driver': 'GTiff',
 'dtype': 'uint8',
 'nodata': None,
 'width': 5895,
 'height': 4915,
 'count': 1,
 'crs': CRS.from_epsg(4326),
 'transform': Affine(8.975758272058824e-05, 0.0, 127.96875,
        0.0, -8.975758272058824e-05, -15.908203125)}

In [6]:
def _tile_offsets(extent, size, step):
    """Return tile start offsets along one axis of length `extent`.

    Offsets start at 0 and advance by `step`. Iteration stops as soon as a tile
    reaches the far edge (start + size >= extent), because any further tile
    would lie entirely inside that one.
    """
    offsets = []
    start = 0
    while start < extent:
        offsets.append(start)
        if start + size >= extent:
            break
        start += step
    return offsets


# Build the list of tiles to cut. Each entry records the tile's pixel offsets
# ('left', 'top') and the path the tile will be written to.
step_y = tile_size_px[0] - tile_overlap_px
step_x = tile_size_px[1] - tile_overlap_px
# A non-positive stride would never advance across the raster.
if step_y <= 0 or step_x <= 0:
    raise ValueError('tile_overlap_px must be smaller than both tile dimensions')

tiles = []
tile_count = 0
# Tiles are numbered column by column: every tile in a column top to bottom,
# then the next column to the right.
for left in _tile_offsets(input_meta['width'], tile_size_px[1], step_x):
    for top in _tile_offsets(input_meta['height'], tile_size_px[0], step_y):
        tile_count += 1
        name = 'part_' + str(tile_count) + '_' + geotiff_file_name
        tiles.append({'left': left,
                      'top': top,
                      'export_path': os.path.join(output_folder, name)})

print('Tiles to make =', len(tiles))

Tiles to make = 56


In [7]:
# Number of worker processes for the tiling pool (defaults to every CPU).
# If the tiling step runs out of RAM, set this manually to a lower number.
processes = cpu_count()
print('thread count =', processes)

thread count = 8


In [8]:
def pad_array(array, tile_size=None, value=None):
    """Pad a (bands, rows, cols) array at the bottom/right up to the tile size.

    Parameters
    ----------
    array : numpy.ndarray
        3-D array laid out as (bands, y, x), as returned by a windowed read.
    tile_size : sequence of two ints, optional
        Target size as [y, x]. Defaults to the global ``tile_size_px``.
    value : scalar, optional
        Fill value for padded pixels. Defaults to the global ``pad_value``.

    Returns
    -------
    numpy.ndarray
        The input unchanged if it already covers the tile size, otherwise a
        padded copy with the same dtype. Axes that are already at least as
        large as the target are left as-is.
    """
    if tile_size is None:
        tile_size = tile_size_px
    if value is None:
        value = pad_value
    _, y_shape, x_shape = array.shape
    # tile_size is ordered [y, x], so index 0 pairs with rows and 1 with columns.
    y_missing = max(tile_size[0] - y_shape, 0)
    x_missing = max(tile_size[1] - x_shape, 0)

    if y_missing or x_missing:
        array = np.pad(array, [(0, 0), (0, y_missing), (0, x_missing)],
                       mode='constant', constant_values=value)
    return array

In [9]:
# Cut a single tile; called once per entry in `tiles`.
def tile_cutter(tile):
    """Cut one tile out of the input GeoTIFF and write it to tile['export_path'].

    `tile` is a dict with 'left' and 'top' pixel offsets and an 'export_path',
    as built by the tiling cell. Uses the globals `geotiff_path`,
    `tile_size_px`, `pad` and `output_compression`. Returns None.
    """
    col_off, row_off = tile['left'], tile['top']
    win_width, win_height = tile_size_px[1], tile_size_px[0]
    with rasterio.open(geotiff_path) as src:
        window = rasterio.windows.Window(col_off, row_off, win_width, win_height)
        pixels = src.read(window=window)
        if pad:
            # Edge windows presumably come back smaller than requested
            # (clipped to the raster); pad them back to full tile size.
            pixels = pad_array(pixels)
        # Start from the source metadata and override what differs per tile.
        tile_meta = src.meta
        tile_meta.update({
            'compress': output_compression,
            'driver': 'GTiff',
            'transform': src.window_transform(window),
            'width': pixels.shape[2],
            'height': pixels.shape[1],
        })
        with rasterio.open(tile['export_path'], 'w', **tile_meta) as dst:
            dst.write(pixels)


In [10]:
# Cut all tiles in parallel with a progress bar. This relies on the workers
# seeing this notebook's globals; on Windows it won't work unless all variables
# and libraries are passed into the function.
with Pool(processes) as worker_pool:
    for _ in tqdm(worker_pool.imap(tile_cutter, tiles), total=len(tiles)):
        pass

  0%|          | 0/56 [00:00<?, ?it/s]

In [59]:
# Serial fallback: slower, but should work on all systems (uncomment to use instead of the pool above)
# for tile in tqdm(tiles):
#     tile_cutter(tile)