In [378]:
import pickle
import os
import numpy as np
import rioxarray
import cv2
import geopandas as gpd
import pandas as pd
from pyproj import CRS
from shapely.geometry import Polygon
from shapely.ops import cascaded_union, unary_union

## Algorithm II: Instance Registration

Data structure: 
1. Each tile is noted with two spatial indices, i and j, which can be used to find its adjacent tiles.
2. Rocks are grouped by tiles and rocks' locations in a tile. Also, each rock is linked with its location and tile indices. Thus the indexing order is as follows: tile indices -> location -> rock index -> rock -> location & tile indices

Pseudo code:
```
for instance in instance:
    location, tile_indices = get_location_in_tile(instance)
    if location is not in tile_overlap:
        register(instance)
    else:
        adjacent_tiles = compute_adjacent_tiles(tile_indices)
        for adjacent_tile in adjacent_tiles:
            for adjacent_instance in adjacent_tile:
                if overlap_bbox(instance, adjacent_instance) > bbox_threshold:
                    if overlap_polygon(instance, adjacent_instance) > polygon_threshold:
                        merge(instance, adjacent_tile)
                        next instance # return
        register(instance)
```

In [495]:
class Instance_Registration(object):
    def __init__(self, instance_dir, 
                 save_shapefile, 
                 overlap_ratio=0.15, 
                 detection_threshold=0.75, 
                 segmentation_threshold=0.5, 
                 mask_overlap=0.1):
        
        assert os.path.exists(instance_dir)
        tile_files = [os.path.join(instance_dir, f) for f in os.listdir(instance_dir) if f.endswith('.pickle')]
        self.tiles = {}
        self.instances = []  # 
        tile_data = self._get_instance(tile_files[0])
        _, _, self.h, self.w = masks = tile_data['masks'].shape
        self.overlap = self.h * overlap_ratio
        self.mask_overlap = mask_overlap
        
        for tile_file in tile_files:
            tile_data = self._get_instance(tile_file)
            masks = tile_data['masks']
            instance_N = masks.shape[0]
            if instance_N == 0:
                continue
            masks = np.squeeze(masks, axis=1)
            tif_name = tile_data['image_name']
            indices = tuple([int(i) for i in tif_name.split('/')[-1].split('.')[0].split('_')])
            tif = rioxarray.open_rasterio(tif_name)
            _, self.tiff_h, self.tiff_w = tif.shape
            epsg = tif.rio.crs.to_epsg()
            self.crs = CRS(epsg)
            # post processing: detection confidence filter
            detect_scores = tile_data['scores']
            masks = masks[detect_scores>detection_threshold]
            for mask in masks:
                # post processing: segmentation confidence filter
                mask = mask > segmentation_threshold
                # post processing: contour analysis
                contours, _ = cv2.findContours(mask.astype(np.uint8).copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
                if len(contours) > 1:
                    areas = [cv2.contourArea(cnt) for cnt in contours]
                    i = np.argmax(areas)
                    contour = contours[i]
                    mask = np.zeros_like(mask).astype(np.uint8)
                    cv2.drawContours(mask, contour, -1, (255),1)
                    mask = mask > 200
                else:
                    contour = contours[0]
                # get location using pixel coords: top, bottom, left, right, middle
                contour = np.squeeze(contour, axis=1)
                locations = self._get_locations(contour)
                # convert to geopanda polygon
                poly = self._convert_poly(contour, tif)
                # instance registration
                instance = {'geometry': poly, "locations": locations, "indices": indices}
                self._instance_registration(instance)
                #self.instances.append(instance)
        
        if len(self.instances) > 0:
            dataframesList = [gpd.GeoDataFrame(crs=self.crs, geometry=[instance['geometry']]) for instance in self.instances]
            rdf = gpd.GeoDataFrame(pd.concat(dataframesList, ignore_index=True))
            rdf.to_file(save_shapefile)
                
    def _instance_registration(self, instance):
        if 'middle' in instance['locations']:
            self._add_instance(instance)
            return None
        
        (x, y) = instance['indices']
        merged = False
        for location in instance['locations']:
            if location == 'left':
                adjacent_indices = (x-1, y)
                adjacent_location = 'right'
                merged = merged | self._merge_instance(instance, location, adjacent_indices, adjacent_location)
            elif location == 'right':
                adjacent_indices = (x+1, y)
                adjacent_location = 'left'
                merged = merged | self._merge_instance(instance, location, adjacent_indices, adjacent_location)
            elif location == 'top':
                adjacent_indices = (x, y+1)
                adjacent_location = 'bottom'
                merged = merged | self._merge_instance(instance, location, adjacent_indices, adjacent_location)
            elif location == 'bottom':
                adjacent_indices = (x, y-1)
                adjacent_location = 'top'
                merged = merged | self._merge_instance(instance, location, adjacent_indices, adjacent_location)
        
        if not merged:
            self._add_instance(instance)
                
    def _add_instance(self, instance):
        self.instances.append(instance)
        instance_id = len(self.instances) - 1
        # update tile table
        if not self.tiles.get(instance['indices'], False):
            # initialize tile
            empty_locations = {'left':[], 'right':[], 'top':[], 'bottom':[], 'middle':[]}
            self.tiles[instance['indices']] = empty_locations
        # update tile
        for location in instance['locations']:
            self.tiles[instance['indices']][location].append(instance_id)
                        
    def _merge_instance(self, instance, location, adjacent_indices, adjacent_location):
        if not self.tiles.get(adjacent_indices, False): 
            return False
        else:
            for adjacent_id in self.tiles[adjacent_indices][adjacent_location]:
                adjacent_instance = self.instances[adjacent_id]
                gpd_poly = gpd.GeoDataFrame(crs=self.crs, geometry=[instance['geometry']])
                adjacent_gpd_poly = gpd.GeoDataFrame(crs=self.crs, geometry=[adjacent_instance['geometry']])
                if self._bbox_intersection(gpd_poly, adjacent_gpd_poly):
                    overlap = self._poly_intersection(adjacent_gpd_poly, adjacent_gpd_poly)
                    if overlap > self.mask_overlap:
                        poly_list = [instance['geometry'].buffer(0), adjacent_instance['geometry'].buffer(0)]
                        union_results = unary_union(poly_list)
                        if union_results.type == "MultiPolygon":
                            all_xy = []
                            for poly in union_results:
                                all_xy.append(np.asarray(poly.exterior.coords.xy))
                            poly_coords = np.concatenate(all_xy, axis=1).transpose()
                            union_poly = Polygon(poly_coords)
                            self.instances[adjacent_id]['geometry'] = union_poly
                        else:
                            union_poly = Polygon(union_results)
                            self.instances[adjacent_id]['geometry'] = union_poly
                        # update tile?? because the adjacent locations have been expanded, 
                        # but updating tiles may cause some issues?
                        return True
        return False
    
    def _get_instance(self, tile_file):
        with open(tile_file, 'rb') as handle:
            tile_data = pickle.load(handle)
        return tile_data
    
    def _get_locations(self, contour):
        x1, y1 = np.min(contour, axis=0)
        x2, y2 = np.max(contour, axis=0)
        locations = []
        if x1 < self.overlap:
            locations.append('left')
        if x2 > self.h - self.overlap:
            locations.append('right')
        if y1 < self.overlap:
            locations.append('top')
        if y2 > self.w - self.overlap:
            locations.append('bottom')
        if len(locations) == 0:
            locations.append('middle')
        return locations
    
    def _convert_poly(self, contour, tif):
        h_size, v_size = tif.rio.resolution()
        h_size = h_size * self.tiff_h / self.h
        v_size = v_size * self.tiff_w / self.w
        h_start, _, _, v_start= tif.rio.bounds()
        contour = contour.reshape(-1, 2).tolist()
        coords = [(h_start + pixel[0]*h_size, v_start + pixel[1]*v_size) for pixel in contour]
        coords = np.asarray(coords)
        poly = Polygon(zip(coords[:, 0].tolist(), coords[:, 1].tolist()))
        return poly
    
    def _poly_intersection(self, a, b):
        intersection = gpd.overlay(a,b, how='intersection')
        return intersection.area.to_numpy()
    
    def _bbox_intersection(self, a, b):
        (xmin_a, ymin_a, xmax_a, ymax_a) = np.asarray(a.bounds)[0]
        (xmin_b, ymin_b, xmax_b, ymax_b) = np.asarray(b.bounds)[0]
        if xmin_a < xmax_b <= xmax_a and (ymin_a < ymax_b <= ymax_a or ymin_a <= ymin_b < ymax_a):
            return True
        elif xmin_a <= xmin_b < xmax_a and (ymin_a < ymax_b <= ymax_a or ymin_a <= ymin_b < ymax_a):
            return True
        elif xmin_b < xmax_a <= xmax_b and (ymin_b < ymax_a <= ymax_b or ymin_b <= ymin_a < ymax_b):
            return True
        elif xmin_b <= xmin_a < xmax_b and (ymin_b < ymax_a <= ymax_b or ymin_b <= ymin_a < ymax_b):
            return True
        else:
            return False

In [496]:
ir = Instance_Registration('data/rocklas/prediction_2d', 'data/rocklas/prediction_shapefiles/merged_inference_rock.shp')



In [497]:
print(len(ir.instances))

253
