In [None]:
import os
from tqdm import tqdm
from osgeo import gdal
import rasterio
import numpy as np
import cv2
from shapely.geometry import Polygon
import geopandas as gpd
from skimage.measure import find_contours as fc
from shapely.geometry import Polygon, LineString
import geopandas as gpd


from mmengine import Config
from mmseg.apis import init_model, inference_model, show_result_pyplot
import matplotlib.pyplot as plt

In [2]:
class RSI_inferencer:
    """Segment a folder of remote-sensing images (RSI) and save georeferenced predictions.

    Each prediction is written as a single-band GeoTIFF that reuses the profile
    (size, CRS, transform, tiling) of the image it was predicted from.
    """

    # Codecs that are lossy by default. Applied to a label map they would alter
    # class indices, so `write_geotiff` swaps them for lossless DEFLATE.
    _LOSSY_CODECS = frozenset({'jpeg', 'webp', 'jxl'})

    def __init__(self, input_path: str, output_path: str, input_file_extension: str = '.tif'):
        """Initializes the RSI_inferencer instance.

        RSI_inferencer is used to make predictions on test/validation data and
        georeference them.

        Args:
            input_path (str): The directory path where input images are located.
            output_path (str): The directory path where output images will be saved.
                It is created on demand and must differ from `input_path`, because
                outputs reuse the input file names.
            input_file_extension (str, optional): The (case-sensitive) file extension
                of input images. An empty string matches every file. Defaults to '.tif'.
        """
        self.input_path = input_path
        self.output_path = output_path
        self.input_file_extension = input_file_extension

        # Sorted so images are always processed in the same order.
        self.input_images = sorted(
            name for name in os.listdir(self.input_path)
            if name.endswith(input_file_extension)
        )

    def _inference_data(self, model, path_of_an_image):
        """Performs inference on a single image using a specified model.

        Args:
            model: The model to use for inference; an mm_segmentation model is expected.
            path_of_an_image (str): The file path of the image to perform inference on.

        Returns:
            tuple: ``(image, labels)``. ``image`` is the raster as a channel-last
            ``(H, W, C)`` float32 array. ``labels`` holds the per-pixel class
            indices predicted by the model, as a numpy array.

        Raises:
            FileNotFoundError: If GDAL cannot open `path_of_an_image` (when GDAL
                exceptions are disabled, `gdal.Open` returns None instead of raising).
        """
        ds = gdal.Open(path_of_an_image)
        if ds is None:
            raise FileNotFoundError(f"GDAL could not open image: {path_of_an_image}")
        try:
            new_rsi_image = np.asarray(ds.ReadAsArray(), dtype=np.float32)
        finally:
            ds = None  # dereference to release the GDAL dataset handle

        # ReadAsArray returns (bands, rows, cols) for multi-band rasters but
        # (rows, cols) for single-band ones; normalise both to (H, W, C).
        if new_rsi_image.ndim == 2:
            new_rsi_image = new_rsi_image[:, :, np.newaxis]
        else:
            new_rsi_image = new_rsi_image.transpose(1, 2, 0)

        results = inference_model(model, new_rsi_image)
        results = results.pred_sem_seg.data[0].cpu().numpy()

        return new_rsi_image, results

    def inference_input_data(self, model):
        """Performs inference on all input images and writes one GeoTIFF per image.

        Each result is saved in `output_path` under the same file name as its input.

        Args:
            model: The model to use for inference; an mm_segmentation model is expected.

        Raises:
            ValueError: If `output_path` resolves to the same directory as
                `input_path`, which would overwrite the input images.
        """
        same_dir = (os.path.normcase(os.path.realpath(self.output_path))
                    == os.path.normcase(os.path.realpath(self.input_path)))
        if same_dir:
            raise ValueError(
                f"output_path must differ from input_path ({self.input_path!r}); "
                "predictions reuse the input file names and would overwrite them."
            )
        os.makedirs(self.output_path, exist_ok=True)

        for single_image_name in tqdm(self.input_images):
            input_file = os.path.join(self.input_path, single_image_name)

            _, results = self._inference_data(model, input_file)

            saving_path = os.path.join(self.output_path, single_image_name)
            self.write_geotiff(results, input_file, saving_path)

    @staticmethod
    def _label_dtype(out):
        """Return the narrowest dtype name (as a string) that can store every value of `out`."""
        if np.issubdtype(out.dtype, np.floating):
            return 'float32' if out.dtype.itemsize <= 4 else 'float64'
        if out.size == 0:
            return 'uint8'
        lo, hi = int(out.min()), int(out.max())
        candidates = ('uint8', 'uint16', 'uint32') if lo >= 0 else ('int16', 'int32')
        for candidate in candidates:
            info = np.iinfo(candidate)
            if info.min <= lo and hi <= info.max:
                return candidate
        return 'float64'

    def write_geotiff(self, out, raster_file, output_file):
        """Writes a georeferenced single-band TIFF using the output from the segmentation model.

        The output reuses the source raster's profile with these adjustments:
        - one band;
        - a dtype chosen to fit `out`;
        - no inherited nodata value, so no class index is masked;
        - no source photometric setting;
        - lossless DEFLATE in place of a lossy codec.

        Args:
            out (ndarray): The 2-D output array from the segmentation model.
            raster_file (str): The file path of the original raster file.
            output_file (str): The file path where the georeferenced TIFF will be saved.

        Raises:
            ValueError: If `out` does not have the source raster's (height, width) shape.
        """
        out = np.asarray(out)
        with rasterio.open(raster_file) as src:
            ras_meta = src.profile
            expected_shape = (src.height, src.width)

        if out.shape != expected_shape:
            raise ValueError(
                f"Prediction shape {out.shape} does not match raster shape "
                f"{expected_shape} of {raster_file}"
            )

        label_dtype = self._label_dtype(out)
        ras_meta.update(count=1, dtype=label_dtype, nodata=None)
        # RGB/YCbCr photometric settings of the source are invalid for a single band.
        ras_meta.pop('photometric', None)
        if str(ras_meta.get('compress', '')).lower() in self._LOSSY_CODECS:
            ras_meta['compress'] = 'deflate'

        with rasterio.open(output_file, 'w', **ras_meta) as dst:
            dst.write(out.astype(label_dtype, copy=False), 1)

In [3]:
# Placeholders: point these at your own folders before running the notebook.
# The output folder must differ from the input folder, because predictions are
# saved under the same file names as their input images.
input_dir_path = 'path to folder with input images to inference'
output_dir_path = 'path to folder where to save results'

In [4]:
# Placeholders: the MMSegmentation config file and the trained checkpoint to load.
path_config, path_checkpoint = (
    'path to config of mmsegmentation model',
    'path to checkpoint of mmsegmentation model',
)

In [None]:
import torch

# Fall back to CPU when no GPU is visible, so the notebook also runs on CPU-only machines.
inference_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

cfg = Config.fromfile(path_config)
cfg.work_dir = "./work_dirs"
model = init_model(cfg, path_checkpoint, device=inference_device)

In [18]:
inferencer = RSI_inferencer(input_dir_path, output_dir_path)
# Fail fast: with no matching images the inference cell below would silently do nothing.
assert inferencer.input_images, (
    f"No '{inferencer.input_file_extension}' images found in {input_dir_path!r}"
)


In [None]:
# Segment every discovered image and write one georeferenced GeoTIFF per input.
inferencer.inference_input_data(model=model)