In [61]:
from pathlib import Path
import os
import json
import warnings

import rasterio
import shapely
import cv2
import torch
import numpy as np
from rasterio.plot import reshape_as_raster, reshape_as_image
from rasterio.windows import Window
from tqdm import tqdm
from shapely.geometry import box, Polygon
import geopandas as gpd

# Input raster for inference. Set the PATH_PREPROCESSED_TILE environment variable
# to point at another tile instead of editing this machine-specific absolute path.
PATH_PREPROCESSED_TILE = os.environ.get(
    'PATH_PREPROCESSED_TILE',
    '/home/quantum/result/results/planet_downloader/0/Kharkiv_SkySat_07-05-22.tif',
)


In [62]:
warnings.filterwarnings("ignore")

In [55]:
# Results are written into the directory that contains the input tile.
head, tail = os.path.split(PATH_PREPROCESSED_TILE)
OUTPUT_PATH = head
# YOLOv5 weights. Set the MODEL_PATH environment variable instead of editing this
# machine-specific absolute path.
MODEL_PATH = os.environ.get('MODEL_PATH', '/home/quantum/multidetection/best.pt')

# Annotation style used when VISUALIZE is True.
FONT = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE = 0.5
COLOR = (255, 0, 0)  # label text colour, in the channel order given by ORDER_CHANNELS
THICKNESS = 1

# NOTE(review): STD_NORMALIZE is not used by get_yolo_predict; confirm whether any
# other cell needs it before removing it.
STD_NORMALIZE = 3.5
ORDER_CHANNELS = (1, 2, 3)  # 1-based raster band indexes fed to the model, in order
STEP = 512                  # tile size in pixels
VISUALIZE = False           # also write a copy of the tile with boxes drawn on it

In [66]:
def _tile_windows(width, height, step):
    """Yield ``(col_off, row_off, win_width, win_height)`` tiles covering a raster.

    Tiles are emitted row by row, left to right. Interior tiles are
    ``step`` x ``step``; tiles on the right/bottom edges are clipped to the
    raster bounds. A raster narrower or shorter than ``step`` is therefore
    still covered, by a single clipped tile in that dimension.
    """
    for row_off in range(0, height, step):
        win_height = min(step, height - row_off)
        for col_off in range(0, width, step):
            yield col_off, row_off, min(step, width - col_off), win_height


def _detect_tiles(model, src, bands_order, tiles, img_size, dst=None):
    """Run ``model`` on every tile of ``src`` and return georeferenced detections.

    Returns a list of ``{'geometry': Polygon, 'label': ...}`` dicts. Each polygon
    is the box's pixel corners mapped through ``src.transform``. When ``dst`` is
    given, each tile is also written to it with the boxes and labels drawn on top.
    """
    detections = []
    for tile in tqdm(tiles):
        col_off, row_off = tile[0], tile[1]
        window = Window(*tile)
        # read() gives (bands, rows, cols); reshape_as_image turns it into
        # (rows, cols, bands). copy() gives cv2 a contiguous, writable array.
        image = reshape_as_image(src.read(bands_order, window=window)).copy()
        boxes = model(image, size=img_size, augment=True).pandas().xyxy[0]
        for row in boxes.itertuples(index=False):
            # Truncate to whole pixels, matching the previous astype(int) behaviour.
            x_local, y_local = int(row.xmin), int(row.ymin)
            box_w, box_h = int(row.xmax - row.xmin), int(row.ymax - row.ymin)
            label = row.name
            # Shift from tile-local to full-raster pixel coordinates before
            # mapping through the raster's affine transform.
            x_glob, y_glob = x_local + col_off, y_local + row_off
            corners = [(x_glob, y_glob), (x_glob, y_glob + box_h),
                       (x_glob + box_w, y_glob + box_h), (x_glob + box_w, y_glob)]
            polygon = Polygon([src.transform * corner for corner in corners])
            detections.append({'geometry': polygon, 'label': label})
            if dst is not None:
                # Drawing happens on the tile image, so tile-local coordinates are used.
                cv2.rectangle(image, (x_local, y_local),
                              (x_local + box_w, y_local + box_h), (51, 255, 51), 2)
                cv2.putText(image, str(label), (x_local, y_local), FONT,
                            FONT_SCALE, COLOR, THICKNESS, cv2.LINE_AA)
        if dst is not None:
            # NOTE(review): dst is declared uint8; this assumes the source bands are
            # already uint8 (e.g. a preprocessed tile) - confirm for other inputs.
            dst.write(reshape_as_raster(image), window=window)
    return detections


def get_yolo_predict(model_path, raster_path, output_path, bands_order, step=512, visualize=False,
                     img_size=512):
    """Tile a georeferenced raster, run a YOLOv5 model per tile, and save detections.

    Parameters
    ----------
    model_path : str
        Path to custom YOLOv5 weights. They are loaded with ``torch.hub`` from
        ``ultralytics/yolov5:master``, which needs network access or a hub cache.
    raster_path : str
        Raster to run detection on.
    output_path : str
        Directory receiving ``detected_objects.geojson``, plus ``predict.tif``
        when ``visualize`` is True.
    bands_order : sequence of int
        1-based band indexes passed to rasterio ``read``, in model channel order.
        The visualization raster is written with 3 bands, so three are expected
        when ``visualize`` is True.
    step : int, default 512
        Tile size in pixels. Edge tiles are clipped to the raster bounds.
    visualize : bool, default False
        If True, also write ``predict.tif`` (uint8, 3 bands, nodata 0) with the
        boxes and labels drawn onto the tiles.
    img_size : int, default 512
        Inference size passed to the model call.

    Raises
    ------
    ValueError
        If ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError(f'step must be a positive integer, got {step}')

    print('Model uploading... \n')
    model = torch.hub.load(
        "ultralytics/yolov5:master", "custom", model_path, verbose=True)
    print('Model has been uploaded \n')

    geo_save_path = os.path.join(output_path, 'detected_objects.geojson')
    dst_raster_path = os.path.join(output_path, 'predict.tif')

    # Context manager so the source dataset is closed even if inference fails.
    with rasterio.open(raster_path) as src:
        tiles = list(_tile_windows(src.width, src.height, step))
        print('Start of predictions...\n')
        if visualize:
            profile = src.profile
            profile.update(dtype='uint8', count=3, nodata=0)
            with rasterio.open(dst_raster_path, 'w', **profile) as dst:
                detections = _detect_tiles(model, src, bands_order, tiles, img_size, dst)
        else:
            # No visualization raster is created: it would only contain nodata.
            detections = _detect_tiles(model, src, bands_order, tiles, img_size)

        # Explicit columns and geometry keep the frame valid (with a CRS) even
        # when nothing was detected.
        gpd.GeoDataFrame(detections, columns=['geometry', 'label'],
                         geometry='geometry', crs=src.crs).to_file(
            geo_save_path, driver='GeoJSON')
    print(f'Predictions are saved in :  {geo_save_path} \n')

    if visualize:
        print(f'Visualization of predictions  is ready! It is in :  {dst_raster_path} \n')
    

In [None]:
# Run detection on the configured tile; results are written under OUTPUT_PATH.
get_yolo_predict(
    model_path=MODEL_PATH,
    raster_path=PATH_PREPROCESSED_TILE,
    output_path=OUTPUT_PATH,
    bands_order=ORDER_CHANNELS,
    step=STEP,
    visualize=VISUALIZE,
)