In [48]:
import logging
import os
import sys
import torch
import cv2
import geopandas as gpd
import numpy as np
import rasterio
import argparse
from tqdm import tqdm
from shapely.geometry import Polygon
from rasterio.windows import Window
from rasterio import windows

from utils_inference import sliding_windows

In [49]:
TILE_FOR_INFER_PATH = 'test_tiles/20220426_115850_ssc6_u0001_visual_clip.tif'
PATH_FOR_SAVE = 'test_tiles'
PATH_TO_MODEL = 'model_enginery_detect.pth'
THRESHOLD = 0.45
BANDS_ORDER = (3,2,1)

In [50]:
def create_reordered_image(raster_path, dst_raster_path, bands_order, step=6000, std_norm=2):
    
    
    src = rasterio.open(raster_path)
    w, h = src.meta['width'], src.meta['height']

    if w < step and h < step:
        step = np.min([w, h])

    whole_rem_w = divmod(w, step)
    whole_rem_h = divmod(h, step)

    all_steps_h = [(0, i * step, 0, step) for i in range(whole_rem_h[0])]

    if whole_rem_h[1] != 0:
        all_steps_h = all_steps_h + [(0, all_steps_h[-1][1] + step, 0, whole_rem_h[1])]

    all_steps = []
    for h_step in all_steps_h:
        all_steps = all_steps + [(i * step, h_step[1], step, h_step[-1]) for i in range(whole_rem_w[0])]
        if whole_rem_w[1] != 0:
            all_steps = all_steps + [(all_steps[-1][0] + step, h_step[1], whole_rem_w[1], h_step[-1])]

    pixels_sum = np.sum([np.sum(src.read(bands_order, window=Window(*i)), axis=(1, 2)) for i in all_steps], axis=0)

    means_channels = (pixels_sum / (w * h)).reshape((3, 1, 1))

    squared_deviation = np.sum(
        [np.sum((src.read(bands_order, window=Window(*i)) - means_channels) ** 2, axis=(1, 2)) for i in all_steps],
        axis=0)

    std = (squared_deviation / (w * h)) ** 0.5

    max_ = (means_channels.reshape(3, -1) + std_norm * std.reshape(3, -1)).reshape(3, 1, 1)

    profile = src.profile
    profile['dtype'] = 'uint8'
    profile['count'] = 3
    profile['nodata'] = 0
    with rasterio.open(
            dst_raster_path, 'w', **profile
    ) as dst:
        for i in all_steps:

            window_normalize = src.read(bands_order, window=Window(*i))
            mask_none = np.where(np.sum(window_normalize, axis=0) == 0, True, False)
            window_normalize = np.clip((window_normalize / max_) * 255, 1, 255).astype(rasterio.uint8)
            for channel in range(3):
                window_normalize[channel][mask_none] = 0

            dst.write(window_normalize, window=Window(*i))

    return dst_raster_path

In [51]:
def sliding_windows(path_or_profile, inner_win_size, middle_win_size, outer_win_size):
    """
    creates three nested sliding windows using rasterio meta
    """

    if isinstance(path_or_profile, str):
        with rasterio.open(path_or_profile) as src:

            src_profile = src.profile
    else:
        src_profile = path_or_profile

    new_profile = src_profile.copy()
    width = new_profile['width']
    height = new_profile['height']

    x_start = 0
    y_start = 0
    x_stop = min(inner_win_size, width)
    y_stop = min(inner_win_size, height)
    inner_wins = []
    outer_wins = []
    middle_wins = []
    offset_mid = int((middle_win_size - min(inner_win_size, width, height)) / 2)
    offset_outer = int((outer_win_size - min(inner_win_size, width, height)) / 2)

    for r in range(0, height, inner_win_size):
        for c in range(0, width, inner_win_size):
            if y_start < height and x_start < width:
                inner_wins.append(windows.Window.from_slices((y_start, y_stop), (x_start, x_stop)))

                middle_wins.append(windows.Window.from_slices((max(y_start - offset_mid, 0),
                                                               min(y_stop + offset_mid, height)),
                                                              (max(x_start - offset_mid, 0),
                                                               min(x_stop + offset_mid, width))))

                outer_wins.append(windows.Window.from_slices((max(y_start - offset_outer, 0),
                                                              min(y_stop + offset_outer, height)),
                                                             (max(x_start - offset_outer, 0),
                                                              min(x_stop + offset_outer, width))))
            x_start = min(x_stop, width)
            x_stop = min(x_start + inner_win_size, width)
        y_start = min(y_stop, height)
        y_stop = min(y_start + inner_win_size, height)
        x_start = 0
        x_stop = min(x_start + inner_win_size, width)

    new_profile.update(tiled=True,
                       compress='lzw')

    return inner_wins, middle_wins, outer_wins, new_profile

In [61]:
class IndexPredictor:
    def __init__(self,
                 raster_path,
                 output_path,
                 weights_path,
                 input_width,                 
                 thresh,
                 win_size):

        self.raster_path = raster_path
        with rasterio.open(self.raster_path) as src:
            self.crs = src.meta['crs'].to_string()
            self.transform = src.meta['transform']
        self.weights_path = weights_path
        self.output_path = output_path
        self.thresh = thresh
        self.model = torch.load(weights_path)
        self.model.eval()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.input_height = input_width
        self.input_width = input_width
        self.raster_out = os.path.join(self.output_path, 'autoindexing.tif')
        self.dst_raster_path = os.path.join(self.output_path, 'visual_tile.tif')
        self.win_size = win_size
        self.middle_win_size = int(self.win_size / self.transform[0])  ## 30 m window
        self.inner_win_size = self.middle_win_size - int(self.middle_win_size * 0.4)
        self.outer_win_size = 0

    def _pad_win(self, img):
        shape = img.shape
        pad_height = 0
        pad_width = 0
        if shape[1] < self.middle_win_size:
            pad_height = self.middle_win_size - shape[1]
        if shape[2] < self.middle_win_size:
            pad_width = self.middle_win_size - shape[2]
        if pad_height != 0 or pad_width != 0:
            img = np.pad(img, pad_width=((0, 0), (0, pad_height), (0, pad_width)), constant_values=0, mode='constant')
        return img

    def _process_batch(self, src_profile, inner_wins, middle_wins, src, offset_mid):
        batch_size = 10
        with rasterio.open(self.raster_out, 'w', **src_profile) as dst:
            with tqdm(total=len(inner_wins)) as pbar:
                for i in range(0, len(inner_wins), batch_size):
                    batch_len = min(batch_size, len(inner_wins) - i)
                    iw = [inner_wins[j] for j in range(i, i + batch_len)]
                    mw = [middle_wins[j] for j in range(i, i + batch_len)]
                    image = [src.read((3,2,1), window=mw[j]) for j in range(batch_len)]
                    image = [self._pad_win(j) for j in image]
                    image = [j.transpose(1, 2, 0) for j in image]
                    img_size = [j.shape[:-1] for j in image]
                    image = [cv2.resize(j, (self.input_height, self.input_width)) for j in image]

                    x = np.stack(image)
                    x = x / 127.5
                    x -= 1.0
                    batch_x = x
                    batch_x = np.moveaxis(batch_x, [0, 1, 2, 3], [0, 2, 3, 1])
                    batch_x = torch.from_numpy(batch_x)
                    batch_x = batch_x.float()

                    preds = self.model(batch_x)
                    for j in range(batch_len):


                        prediction = preds[j] > self.thresh
                        prediction = prediction.cpu().numpy()

                        pred = cv2.resize((prediction[0]*255).astype(np.uint8), (img_size[j][1], img_size[j][0]))

                        pred = pred[
                               int(min(offset_mid, iw[j].row_off)):int(min(offset_mid, iw[j].row_off)) + iw[j].height,
                               int(min(offset_mid, iw[j].col_off)):int(min(offset_mid, iw[j].col_off)) + iw[j].width]

                        dst.write(pred,1, window=iw[j])
                        pbar.update()

    def segment_grove(self):
        # predicts indices from a raster using sliding windows
        raster_in = self.raster_path
        logging.info("segmenting grove")
 
                     
        dst_raster_path = create_reordered_image(raster_in, self.dst_raster_path, BANDS_ORDER, step=6000, std_norm=2)
        with rasterio.open(dst_raster_path) as src:
            src_profile = src.profile
            src_profile.update(count=1,
                               nodata=None,
                               dtype='uint8')
        
        
        
            offset_mid = int((self.middle_win_size - self.inner_win_size) / 2)

            inner_wins, middle_wins, outer_wins, new_profile = sliding_windows(src_profile,
                                                                               self.inner_win_size,
                                                                               self.middle_win_size,
                                                                               self.outer_win_size)

            self._process_batch(src_profile, inner_wins, middle_wins, src, offset_mid)

            with rasterio.open(self.raster_out,'r') as src:

                mask = src.read()
                meta = src.meta

                contours, _ = cv2.findContours(
                    mask[0],
                    cv2.RETR_LIST,
                    cv2.CHAIN_APPROX_SIMPLE)
                polygons = self.polygonize(contours, meta)

                gpd.GeoDataFrame(polygons, columns=['geometry'],crs=meta['crs']).to_file(
                    os.path.join(self.output_path, 'geodf_enginery.geojson'), driver='GeoJSON')

    def polygonize(self, contours, meta):
        """Credit for base setup: Michael Yushchuk. Thank you!"""
        polygons = []
        for i in tqdm(range(len(contours))):
            c = contours[i]
            n_s = (c.shape[0], c.shape[2])
            if n_s[0] > 2:
                    polys = [tuple(i) * meta['transform'] for i in c.reshape(n_s)]

            polygons.append(Polygon(polys))
        return polygons

In [62]:
instance_predictor = IndexPredictor( raster_path = TILE_FOR_INFER_PATH, output_path = PATH_FOR_SAVE, weights_path = PATH_TO_MODEL, thresh =THRESHOLD, \
                                    input_width = 288, win_size = 288)

In [63]:
instance_predictor.segment_grove()

 03:24:53  segmenting grove


100% 480/480 [00:53<00:00,  9.02it/s]
  polys = [tuple(i) * meta['transform'] for i in c.reshape(n_s)]
100% 1/1 [00:00<00:00, 3506.94it/s]
  pd.Int64Index,
