In [1]:
import os
from contextlib import ExitStack

import rasterio
from rasterio.windows import Window
import numpy as np
import cv2
from tqdm import tqdm

In [12]:
def save_tile(raster, window, output_path):
    """Write the part of ``raster`` covered by ``window`` to a new GeoTIFF.

    Parameters
    ----------
    raster
        Open rasterio dataset to read from.
    window
        ``rasterio.windows.Window`` in the pixel coordinates of ``raster``.
        It may reach past the raster edge.
    output_path
        Path of the GeoTIFF to create.

    The output has ``window.height`` x ``window.width`` pixels, the band
    count and CRS of ``raster``, the dtype of its first band, and the affine
    transform of ``window`` so the tile keeps its georeferencing.
    """
    fully_inside = (
        window.col_off >= 0
        and window.row_off >= 0
        and window.col_off + window.width <= raster.width
        and window.row_off + window.height <= raster.height
    )
    # Windows centred near the raster edge reach outside the dataset. A
    # boundless read pads the missing area, so the array always matches the
    # height/width and transform declared for the output below. Windows that
    # are fully inside keep using the ordinary read path.
    tile = raster.read(window=window, boundless=not fully_inside)
    transform = raster.window_transform(window)
    with rasterio.open(
        output_path,
        'w',
        driver='GTiff',
        height=window.height,
        width=window.width,
        count=raster.count,
        dtype=raster.dtypes[0],
        crs=raster.crs,
        transform=transform,
    ) as dst:
        dst.write(tile)

def crop_image(src, x, y, crop_size):
    """Return a square window of side ``crop_size`` centred on pixel (x, y).

    ``x`` is used as the column and ``y`` as the row of the centre. ``src``
    is not consulted, so the window is not clamped to any raster's extent.
    """
    half = crop_size // 2
    col_off = x - half
    row_off = y - half
    return rasterio.windows.Window(col_off, row_off, crop_size, crop_size)

def process_positive_files_with_overlap(ground_truth_path,
                                        rgb_paths,
                                        stream_order_path,
                                        output_dir,
                                        crop_size=128,
                                        overlap_rate=0.5,
                                        tile_number=0):
    """Cut tiles centred on positive ground-truth pixels.

    Every pixel of band 1 of the ground-truth raster that is > 0 (after
    flipping the sign when the band contains negative values) is a candidate
    centre. A candidate is skipped when it lies closer than
    ``crop_size * overlap_rate`` pixels, on both axes, to a centre that was
    already used. For each kept centre, two tiles (``crop_size`` and
    ``2 * crop_size``) are written from the ground truth, the raster at
    ``stream_order_path`` and every raster in ``rgb_paths``.

    Parameters
    ----------
    ground_truth_path : path of the ground-truth raster.
    rgb_paths : iterable of raster paths written to ``rgb_images/`` as
        ``rgb_{i}_tile_{size}_{n}.tif``.
    stream_order_path : path of the raster written to ``dem/`` as
        ``dem_tile_{size}_{n}.tif``.
    output_dir : root folder; ``ground_truth/``, ``dem/`` and ``rgb_images/``
        are created beneath it.
    crop_size : side of the small tile in pixels.
    overlap_rate : fraction of ``crop_size`` used as the minimum spacing
        between kept centres.
    tile_number : number given to the first tile written.

    Returns
    -------
    int
        The next unused tile number.

    Notes
    -----
    The same pixel coordinates index every raster, so this assumes they all
    share one pixel grid — TODO confirm against the data preparation step.
    """
    os.makedirs(output_dir, exist_ok=True)

    gt_dir = os.path.join(output_dir, "ground_truth")
    stream_dir = os.path.join(output_dir, "dem")
    rgb_dir = os.path.join(output_dir, "rgb_images")
    for directory in (gt_dir, stream_dir, rgb_dir):
        os.makedirs(directory, exist_ok=True)

    # ExitStack closes every dataset even when a read or write fails midway.
    with ExitStack() as stack:
        gt_src = stack.enter_context(rasterio.open(ground_truth_path))
        stream_src = stack.enter_context(rasterio.open(stream_order_path))
        rgb_srcs = [stack.enter_context(rasterio.open(path)) for path in rgb_paths]

        gt_data = gt_src.read(1)
        if gt_data.min() < 0:
            gt_data = -gt_data
        # np.where returns (rows, cols), i.e. (y, x).
        y_indices, x_indices = np.where(gt_data > 0)

        cropped_regions = []
        overlap_th = crop_size * overlap_rate

        for x, y in tqdm(zip(x_indices, y_indices), total=len(x_indices)):
            overlap = any(
                abs(prev_x - x) < overlap_th and abs(prev_y - y) < overlap_th
                for prev_x, prev_y in cropped_regions
            )
            if overlap:
                continue

            cropped_regions.append((x, y))

            # Both sizes share one tile number so they can be paired later.
            for size in [crop_size, crop_size * 2]:
                window = crop_image(stream_src, x, y, size)
                save_tile(stream_src, window, os.path.join(stream_dir, f'dem_tile_{size}_{tile_number}.tif'))
                window = crop_image(gt_src, x, y, size)
                save_tile(gt_src, window, os.path.join(gt_dir, f'ground_truth_tile_{size}_{tile_number}.tif'))

                for i, rgb_src in enumerate(rgb_srcs):
                    window = crop_image(rgb_src, x, y, size)
                    save_tile(rgb_src, window, os.path.join(rgb_dir, f'rgb_{i}_tile_{size}_{tile_number}.tif'))

            tile_number += 1

    return tile_number

def process_files_with_negative_check(ground_truth_path,
                                      rgb_paths,
                                      stream_order_path,
                                      output_dir,
                                      crop_size=128,
                                      overlap_rate=0.5,
                                      buffer_size=50,
                                      tile_number=0):
    """Cut tiles that contain no positive ground-truth pixel.

    Candidate centres are generated from every positive pixel of band 1 of
    the ground truth by adding offsets
    ``range(-buffer_size, buffer_size + 1, crop_size)`` on both axes. For
    each candidate and each size (``crop_size``, ``2 * crop_size``) the
    window is skipped when its top-left corner is closer than
    ``overlap_rate * size`` (Euclidean distance) to a corner already used,
    or when the ground truth inside the window has any value > 0. Otherwise
    the window is written from the ground truth, the raster at
    ``stream_order_path`` and every raster in ``rgb_paths``.

    Parameters
    ----------
    ground_truth_path : path of the ground-truth raster.
    rgb_paths : iterable of raster paths written to ``rgb_images/``.
    stream_order_path : path of the raster written to ``dem/``.
    output_dir : root folder; ``ground_truth/``, ``dem/`` and ``rgb_images/``
        are created beneath it.
    crop_size : side of the small tile in pixels; also the step of the
        offset grid.
    overlap_rate : fraction of the tile size used as the minimum distance
        between accepted window corners.
    buffer_size : half-extent, in pixels, of the offset grid.
    tile_number : number given to the first tile written.

    Returns
    -------
    int
        The next unused tile number.

    Notes
    -----
    ``tile_number`` advances after every saved (window, size) pair, so the
    two sizes of one location do not share a number here, unlike in
    ``process_positive_files_with_overlap``.
    """
    os.makedirs(output_dir, exist_ok=True)

    gt_dir = os.path.join(output_dir, "ground_truth")
    stream_dir = os.path.join(output_dir, "dem")
    rgb_dir = os.path.join(output_dir, "rgb_images")
    for directory in (gt_dir, stream_dir, rgb_dir):
        os.makedirs(directory, exist_ok=True)

    # Open every raster once instead of re-opening the RGB files per tile;
    # ExitStack closes them all even if a read or write fails.
    with ExitStack() as stack:
        gt_src = stack.enter_context(rasterio.open(ground_truth_path))
        stream_src = stack.enter_context(rasterio.open(stream_order_path))
        rgb_srcs = [stack.enter_context(rasterio.open(path)) for path in rgb_paths]

        gt_data = gt_src.read(1)
        if gt_data.min() < 0:
            gt_data = -gt_data
        positive_points = np.argwhere(gt_data > 0)
        cropped_regions = []

        # NOTE(review): np.argwhere yields (row, col) pairs, so `px` holds a
        # row index and `py` a column index, yet below they are used as
        # column and row. Candidate windows therefore sit at transposed
        # locations. Left unchanged on purpose: with the coordinates
        # un-swapped, any offset smaller than half the crop keeps the seed
        # pixel inside its own window and the candidate is always rejected,
        # so the set of tiles produced would change completely. Confirm the
        # intended sampling scheme before fixing.
        for px, py in tqdm(positive_points):
            for dx in range(-buffer_size, buffer_size + 1, crop_size):
                for dy in range(-buffer_size, buffer_size + 1, crop_size):
                    x, y = px + dx, py + dy
                    if not (0 <= x < gt_src.width and 0 <= y < gt_src.height):
                        continue

                    for size in [crop_size, crop_size * 2]:
                        window = crop_image(gt_src, x, y, size)
                        overlap = any(
                            np.sqrt((prev_x - window.col_off)**2 + (prev_y - window.row_off)**2) < overlap_rate * size
                            for prev_x, prev_y in cropped_regions
                        )
                        if overlap:
                            continue

                        # Reject any window that touches a positive pixel.
                        cropped_gt = gt_src.read(1, window=window)
                        if np.any(cropped_gt > 0):
                            continue

                        cropped_regions.append((window.col_off, window.row_off))
                        save_tile(gt_src, window, os.path.join(gt_dir, f'negative_ground_truth_tile_{size}_{tile_number}.tif'))
                        save_tile(stream_src, window, os.path.join(stream_dir, f'dem_tile_{size}_{tile_number}.tif'))

                        for i, rgb_src in enumerate(rgb_srcs):
                            save_tile(rgb_src, window, os.path.join(rgb_dir, f'rgb_{i}_tile_{size}_{tile_number}.tif'))

                        tile_number += 1

    return tile_number


In [13]:
# Alternative input roots, currently disabled:
# root_paths = ['/home/macula/SMATousi/Gullies/ground_truth/organized_data/MO_Downloaded_HUCs/',
#               '/home/macula/SMATousi/Gullies/ground_truth/organized_data/OH_Downloaded_HUCs/']

root_paths = ['/home/macula/SMATousi/Gullies/ground_truth/organized_data/MO+IA_downloaded_Test_HUCs/']

pos_output_dir = '/home/macula/SMATousi/Gullies/ground_truth/organized_data/test_data_with_context/pos/'
neg_output_dir = '/home/macula/SMATousi/Gullies/ground_truth/organized_data/test_data_with_context/neg/'

starting_pos_tile_number = 0
starting_neg_tile_number = 0

last_neg_tile_number = 0
last_pos_tile_number = 0

for root_path in root_paths:
    all_hucs = os.listdir(root_path)

    for huc_name in all_hucs:
        # Only entries whose name ends in "done" are processed.
        if not huc_name.endswith("done"):
            continue

        print("Starting with HUC: ", huc_name)

        huc_path = os.path.join(root_path, huc_name)
        data_path = os.path.join(huc_path, "data")

        GT_path = os.path.join(data_path, "gt/rasterized_gt.tif")
        rgb_paths = [
            os.path.join(data_path, f'merged/tile_{tile_id}__merged.tif')
            for tile_id in (10, 12, 14, 16, 18, 20)
        ]
        dem_path = os.path.join(data_path, 'merged/dem_tile__merged.tif')

        # One running counter is shared by both passes: the positive pass
        # resumes where the previous HUC's negative pass stopped, and the
        # negative pass resumes where this HUC's positive pass stopped.
        starting_pos_tile_number = last_neg_tile_number

        last_pos_tile_number = process_positive_files_with_overlap(
            GT_path,
            rgb_paths,
            dem_path,
            pos_output_dir,
            crop_size=128,
            overlap_rate=0.25,
            tile_number=starting_pos_tile_number,
        )

        last_neg_tile_number = process_files_with_negative_check(
            GT_path,
            rgb_paths,
            dem_path,
            neg_output_dir,
            crop_size=128,
            overlap_rate=0.25,
            buffer_size=10,
            tile_number=last_pos_tile_number,
        )

Starting with HUC:  HUC_071100080401-done


2136it [04:43,  7.52it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 2136/2136 [02:59<00:00, 11.89it/s]


Starting with HUC:  HUC_070801030408-done


3575it [06:57,  8.57it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 3575/3575 [03:36<00:00, 16.52it/s]


Starting with HUC:  HUC_071100060101-done


3422it [05:52,  9.72it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 3422/3422 [03:11<00:00, 17.86it/s]


Starting with HUC:  HUC_070802050807-done


6340it [11:13,  9.41it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 6340/6340 [04:28<00:00, 23.60it/s]


Starting with HUC:  HUC_070801050302-done


4591it [08:16,  9.25it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4591/4591 [05:51<00:00, 13.05it/s]


In [15]:
# NOTE(review): scratch arithmetic with no stated meaning — presumably a file
# count divided by the number of files written per tile; confirm or remove.
7536/12

628.0