# Implementation of the TernausNetV2 solution from the DeepGlobe Building Extraction Competition
# This demo will work with any GeoTiff, including S3-based Cloud Optimized GeoTiffs


In [None]:
%load_ext autoreload
%autoreload 2

from cw_tiler import main
from cw_tiler import utils
from cw_tiler import vector_utils
from cw_nets.Ternaus_tools import tn_tools 
from cw_nets.tools import util as base_tools

In [None]:

# --- Model ------------------------------------------------------------------
# Pytorch weights; the model is obtained from https://github.com/ternaus/TernausNetV2
modelPath = '/home/dlindenbaum/cosmiqGit/TernausNetV2/weights/deepglobe_buildings.pt'

# --- Input / output ---------------------------------------------------------
# SpaceNet Cloud Optimized GeoTiff to process.
rasterPath = "s3://spacenet-dataset/AOI_2_Vegas/srcData/rasterData/AOI_2_Vegas_MUL-PanSharpen_Cloud.tif"
# Directory and base file name used for the output rasters.
dataLocation = "/home/dlindenbaum/"
outputName = "AOI_2_Vegas_buildings_v7.tif"

# If true, only perform analysis on center 1km of area.
smallExample = True

# --- Analysis grid ----------------------------------------------------------
# Side length of each square grid cell, in meters.
cell_size_meters = 200

# Spacing between consecutive grid starting points, in meters.
# With 200 m cells, a 190 m stride leaves a 10 m overlap between neighbours.
stride_size_meters = 190

# Number of pixels along each side of a tile.
# cell_size_meters / tile_size_pixels == pixel size in meters.
tile_size_pixels = 650

In [None]:
# Query the processing details for the target raster; the two returned values
# are kept as `rasterBounds` and `dst_profile` for the following cells.
# The notebook-level `smallExample` setting is forwarded (instead of a
# hard-coded True) so that changing it in the configuration cell takes effect.
rasterBounds, dst_profile = base_tools.get_processing_details(rasterPath, smallExample=smallExample)



In [None]:
# Build the dictionary of grid-cell lists from the raster bounds and the
# grid settings (cell size, stride, tile size in pixels).
cells_list_dict = base_tools.generate_cells_list_dict(
    rasterBounds,
    cell_size_meters,
    stride_size_meters,
    tile_size_pixels,
)

In [None]:
# Create the raster mask(s) for every cell list.
# NOTE(review): presumably this runs the model at `modelPath` over each cell
# and writes results under `dataLocation` — confirm in cw_nets.tools.util.
mask_dict_list = base_tools.createRasterMask(
    rasterPath,
    cells_list_dict,
    dataLocation,
    outputName,
    dst_profile,
    modelPath,
    tile_size_pixels,
)


 Next we calculate information about the target raster:

    * The raster bounds in WGS84 (lat, long)
    * The raster bounds in UTM coordinates (x, y in meters)
    * A VRT profile which represents the image translated to UTM coordinates (UTM is a projected coordinate system with units in meters, so grid cells are square)
    
    

In [None]:
import os

# Post-process the mask results; the second argument is the output path
# dataLocation/outputName.
resultDict = base_tools.process_results_mask(
    mask_dict_list,
    os.path.join(dataLocation, outputName),
)


Determine specific characteristics for Analysis grid.  
This model was trained with the expectation of receiving
* 650 x 650 pixel images
* representing ~200m x 200m in area
* 8 Bands representing Coastal, Blue, Green, Yellow, Red, Red Edge, NIR1, NIR2

We need to slide a window of ~200m in size across the target Tiff and then recombine the results

* cell_size_meters:  The size of the analysis square (200m in this case)
* stride_size_meters: How far apart each grid anchor point should be.  If ~50m overlap is desired, then we would specify ~150m in this case
* tile_size_pixels: Target size of returned window (650 x 650 pixels)



In [None]:
# Polygonize the processed mask, then write the resulting features to a
# GeoJSON file that sits next to the output raster (same base name, with the
# '.tif' suffix swapped for '.geojson').
results, src_profile = base_tools.polygonize_results_mask(resultDict)

# `str.replace` (previously misspelled `relace`, which raised AttributeError)
# derives the GeoJSON file name from the raster file name.
base_tools.write_results_tojson(
    results,
    os.path.join(dataLocation, outputName.replace('.tif', '.geojson')),
)
             
        
        
    
    
    
     
    
            


                
                

In [None]:
%%time
results = []
#mask= data_mask==0
outputTifMask = os.path.join(dataLocation, outputName.replace('.tif', 'Final_mask.tif'))

In [None]:
# NOTE(review): `rasterio` is not imported in any cell visible here — confirm
# it is imported before this cell runs.
with rasterio.open(outputTifMask) as src:
    # Band 1 pixel values and the dataset-level mask.
    image = src.read(1)
    mask = src.dataset_mask()
    # Keep the dataset profile so it can be inspected after the file is closed.
    src_data_final = src.profile

In [None]:
# Show the profile captured from the final mask raster.  A bare expression in
# the middle of a cell is not displayed, so print it explicitly.
print(src_data_final)
# Count the pixels whose value is 0.  The previous expression,
# image[image==0].sum(), added up the selected zeros and so always printed 0.
print((image == 0).sum())

In [None]:
%%time
results = []
#mask= data_mask==0
with rasterio.open(outputTifMask) as src:
    
    image = src.read(1)
    mask=src.dataset_mask()
    for i, (s, v) in tqdm(enumerate(rasterio.features.shapes(image, mask=mask, transform=src.transform))):
        results.append({'properties': {'raster_val': v}, 'geometry': s})

In [None]:
%%time
results = []
#mask= data_mask==0
with rasterio.open(outputTifMask) as src:
    
    image = src.read(1)
    mask=image>0
    for i, (s, v) in tqdm(enumerate(rasterio.features.shapes(image, mask=mask, transform=src.transform))):
        results.append({'properties': {'raster_val': v}, 'geometry': s})

In [None]:
# Boolean array marking the pixels of `image` that are greater than zero.
mask_test = (image > 0)

In [None]:
# Element-wise comparison of the current `mask` with `mask_test`.
(mask == mask_test)

In [None]:
# NOTE(review): `src` was opened inside a `with` block in an earlier cell, so
# the dataset is presumably closed by the time this runs — confirm before
# consuming the result.
rasterio.features.dataset_features(
    src,
    bidx=None,
    sampling=1,
    band=True,
    as_mask=False,
    with_nodata=False,
    geographic=True,
    precision=-1,
)

In [None]:
%%time
for cells_list_id, cells_list in cells_list_dict.items():
        
        outputTifMask = os.path.join(dataLocation, outputName.replace('.tif', '{}_mask.tif'.format(cells_list_id)))
        outputTifCountour = os.path.join(dataLocation, outputName.replace('.tif', '{}_contour.tif'.format(cells_list_id)))
        outputTifCount = os.path.join(dataLocation, outputName.replace('.tif', '{}_count.tif'.format(cells_list_id)))


        with rasterio.open(outputTifMask) as src_mask, \
                rasterio.open(outputTifCountour) as src_seed:
            src_mask_profile = src_mask.profile
            data_mask = np.memmap('mask{}.mymemmap', dtype='uint8', mode='w+', 
                                     shape=(1, src_mask_profile['height'], src_mask_profile['width'])
                                    )
            data_mask = src_mask.read()
            del data_mask

            src_seed_profile = src_seed.profile
            data_seed = np.memmap('seed{}.mymemmap'.format(cells_list_id), dtype='uint8', mode='w+', 
                                     shape=(1, src_seed_profile['height'], src_seed_profile['width'])
                                    )
            data_seed = src_seed.read()
            del data_seed
    

In [None]:
%%time
outputTifMask = os.path.join(dataLocation, outputName.replace('.tif', '_mask.tif'))
outputTifCountour = os.path.join(dataLocation, outputName.replace('.tif', '_contour.tif'))
outputTifCount = os.path.join(dataLocation, outputName.replace('.tif', '_count.tif'))
import cv2


with rasterio.open(outputTifMask) as src_mask, \
        rasterio.open(outputTifCountour) as src_seed:
    src_mask_profile = src_mask.profile
    src_seed_profile = src_seed.profile

print("data_mask")
data_mask = np.memmap('mask.mymemmap', dtype='uint8', mode='r', 
                             shape=(1, src_mask_profile['height'], src_mask_profile['width'])
                            )    

print("data_seed")
data_seed = np.memmap('seed.mymemmap', dtype='uint8', mode='r', 
                             shape=(1, src_seed_profile['height'], src_seed_profile['width'])
                            )

    
print("start")
ret, markers = cv2.connectedComponents(np.squeeze(data_seed))
print("watershed")
markers = cv2.watershed(np.asarray([data_mask, data_mask, data_mask]).astype(np.uint8),markers)

print(np.min(markers))    

    