# Part 6: Segmentation with new orthophotos

In this notebook, the previously trained model for pedestrian refuge island segmentation is applied to new aerial images. At the end, the predicted coordinates of the detected objects are stored in a GeoJSON file, ready for comparison.

Precondition for this part:
- The TIF images must be tagged with the Spherical Mercator projection coordinates. Otherwise, the coordinates will be invalid. Alternatively, one can manually change the EPSG Code in the transformer function.

In [None]:
import os
import rasterio as rio
import cv2
import tifffile
from fastai.vision.all import *
from geojson import Feature, Point, FeatureCollection, dump
from skimage import io, exposure
from skimage.transform import resize
from skimage.color import rgba2rgb
from skimage.util import img_as_ubyte
from multiprocessing import Pool
from functools import partial
from pyproj import Transformer

#### Paths, Directories and Downloads

In [None]:
import urllib.request  # not among the notebook's explicit imports, so import it here

# Folder layout: everything lives below "<working directory>/data".
CURRENT_PATH = Path(os.getcwd())
DATASET_PATH = CURRENT_PATH / "data"
IMAGES_PATH = DATASET_PATH / 'islands'
PRED_MASK_PATH = DATASET_PATH / 'pred_mask'
PRED_MASK_PATH_TFMS = DATASET_PATH / 'pred_mask_tfms'

# Create Directory
# Each folder is checked on its own, so an existing "data" folder that lacks
# one of its sub-folders is completed instead of being skipped entirely.
missing_dirs = [
    directory
    for directory in (DATASET_PATH, IMAGES_PATH, PRED_MASK_PATH, PRED_MASK_PATH_TFMS)
    if not directory.exists()
]
for directory in missing_dirs:
    directory.mkdir(parents=True, exist_ok=True)
if missing_dirs:
    print('Directories created!')

# Download exporter
# Check for the exact file that is loaded later; an unrelated .pkl in the
# working directory must not suppress the download.
# NOTE(review): the downloaded file is a pickle; only load it if the source is trusted.
learner_file = CURRENT_PATH / 'island_segmentation_resnet18.pkl'
if not learner_file.exists():
    urllib.request.urlretrieve(url=r'https://drive.switch.ch/index.php/s/9W2w1m5k0mENn7F/download', filename=learner_file)
    print('Learner downloaded!')

#### 6 Steps From Orthophoto (TIFF) to GeoJSON

1. Provide Images as TIFF into directory: _/PT1-Refuge_Islands/data/islands
2. TIFF (with Spherical Mercator projection coordinates) converted into PNG
3. Learner imported and set up 
4. Masks predicted
5. Adjusted mask size
6. Coordinates mapped, extracted and saved as GeoJSON

In [None]:
# 2. Convert provided images from TIFF to PNG

def read_image(img, channels=None):
    """Load a TIFF file as an array, optionally keeping only selected channels.

    `img` is passed to `tifffile.TiffFile`. When `channels` is truthy it is
    used to index the third axis of the loaded array; a falsy value (the
    default `None`, for example) returns the array unchanged.
    """
    with tifffile.TiffFile(img) as tiff_handle:
        pixels = tiff_handle.asarray()
    return pixels[:, :, channels] if channels else pixels

def reduce_image_array(i_array):
    """Scale an image array by its maximum and return it as 8-bit data.

    The array is divided by its largest value and the result is converted
    with `img_as_ubyte`.

    An array whose maximum is 0 (e.g. a completely black tile) cannot be
    scaled that way: the division would produce NaN values. Such input is
    returned as an all-zero uint8 array of the same shape instead.
    """
    peak = i_array.max()
    if peak == 0:
        # Nothing to scale; avoid the 0/0 division.
        return np.zeros(i_array.shape, dtype=np.uint8)
    return img_as_ubyte(i_array / peak)

def save_img_as_png(im_file, dest_folder, channels=[2,1,0]):
    """Convert one TIFF file into a PNG inside `dest_folder`.

    The image is loaded with `read_image` (keeping `channels`; the default
    selects channels 2, 1, 0 in that order), scaled to 8 bit with
    `reduce_image_array` and written next to the other PNGs under the
    original file's stem.

    Returns the path of the written PNG.
    """
    pixels = reduce_image_array(read_image(im_file, channels=channels))
    png_path = dest_folder / (im_file.stem + ".png")
    io.imsave(png_path.absolute(), pixels)
    return png_path

def split_list(l, number_of_parts):
    """Split `l` into `number_of_parts` consecutive sub-lists.

    The split is done by `np.array_split`, so the parts differ in length by
    at most one element and trailing parts may be empty when there are fewer
    items than parts. Every part is returned as a plain Python list.
    """
    parts = []
    for chunk in np.array_split(l, number_of_parts):
        parts.append(chunk.tolist())
    return parts

def convert_list_of_images(image_list, out_folder):
    """Convert every TIFF path in `image_list` to a PNG stored in `out_folder`.

    Each path is handed to `save_img_as_png` with its default channel
    selection. Nothing is returned.
    """
    for tiff_path in image_list:
        save_img_as_png(im_file=tiff_path, dest_folder=out_folder)

        
# Verify tif exist
import shutil  # not among the notebook's explicit imports, so import it here

# Glob once: the same list answers "are there TIFFs?" and feeds the conversion.
image_list = list(IMAGES_PATH.glob("*.tif"))

# Defined even when nothing is converted, so the name always exists afterwards.
IMAGES_PNG_PATH = DATASET_PATH / 'images_png'

if image_list:

    # Create png folder (asking before an existing one is replaced)
    if not IMAGES_PNG_PATH.exists():
        os.mkdir(IMAGES_PNG_PATH)
    else:
        usr_input = input('PNG folder already existis. Delete? (Y/N): ').upper()
        if usr_input == 'Y':
            shutil.rmtree(IMAGES_PNG_PATH)
            os.mkdir(IMAGES_PNG_PATH)
        else:
            exit()

    # At most 10 worker processes, but never more workers than images.
    number_of_parallelism = min(10, len(image_list))

    # Split list into n sublists [[][]], one per worker
    image_list_splitted = split_list(image_list, number_of_parallelism)

    do_conversion = partial(convert_list_of_images, out_folder=IMAGES_PNG_PATH)

    # NOTE(review): the worker function is defined in the notebook itself;
    # presumably this requires a fork-based multiprocessing start method -- confirm
    # on the target platform.
    with Pool(number_of_parallelism) as p:
        p.map(do_conversion, image_list_splitted)
    print('Images converted!')

In [None]:
# 3. Import and set up learner

# Metric
def calc_accuracy(inp, targ):
    """Accuracy restricted to the pixels whose target value is 255.

    `inp` holds per-class scores along dim 1 (the prediction is its argmax
    over that dim); `targ` carries a singleton dim 1 that is squeezed away
    before comparing. Only positions where the target equals 255 contribute
    to the returned mean.
    """
    labels = targ.squeeze(1)
    is_island = labels == 255
    predicted = inp.argmax(dim=1)
    hits = predicted[is_island] == labels[is_island]
    return hits.float().mean()

# Label
def label_func(fn):
    """Return the mask path for image `fn`: same stem, `.png`, under TEST_MASKS_PATH.

    NOTE(review): TEST_MASKS_PATH is not defined in the visible part of this
    notebook; presumably the function is only kept so the exported learner
    can be loaded -- confirm before calling it.
    """
    mask_name = fn.stem + ".png"
    return TEST_MASKS_PATH / mask_name


# Codes
# 256 class names: indices 0..254 are "not_island", index 255 is 'island'.
codes = ["not_island"] * 255 + ['island']

# Image Augmentation
batch_tfms = aug_transforms(
    max_warp=0.,
    max_zoom=1.05,
    max_lighting=0.1,
    flip_vert=True,
)

# Learner
# NOTE(review): the exported learner is a pickle file; load it only from a trusted source.
learner = load_learner(Path('island_segmentation_resnet18.pkl'))

In [None]:
# 4. Predict Masks based on PNG's

def pred_mask():
    """Predict a mask for every PNG in IMAGES_PNG_PATH and save it as PNG.

    Each prediction is written to PRED_MASK_PATH under the source file name
    prefixed with 'pred_'. Progress is printed as 'item i/total'.
    """
    # Make sure the output folder exists before the first mask is saved.
    PRED_MASK_PATH.mkdir(parents=True, exist_ok=True)

    print('Prediction started..')

    # List the folder once instead of once per processed image.
    images = IMAGES_PNG_PATH.ls()
    total = len(images)

    for cnt, img in enumerate(images, start=1):

        # Predict mask based on original test image (only the mask is needed)
        pred_msk_array, _, _ = learner.predict(item=img)

        # Save predicted mask
        pred_msk = PRED_MASK_PATH / ('pred_' + img.name)
        PILMask.create(pred_msk_array).save(pred_msk, format="png")

        print('item ' + str(cnt) + '/' + str(total))


pred_mask()

In [None]:
# 5. Align size of predicted Masks

def get_original_shape(item_name,org_dir):
    """Return the raster shape of the original image named `item_name`.

    `org_dir` is searched for a file whose stem equals `item_name`; the first
    match is opened with rasterio and its `dataset.shape` is returned.
    Returns None when no file in `org_dir` matches.
    """
    # Search the directory that was passed in rather than a global folder.
    for img in org_dir.iterdir():
        if img.stem == item_name:
            with rio.open(img) as dataset:
                return dataset.shape
    return None

def resize_items(item,x,y):
    """Resize `item` by calling its `resize` method with the size pair (x, y).

    The pair is wrapped in a `torch.Size` before being passed on; the resized
    object returned by `item.resize` is handed back unchanged.
    """
    target_size = torch.Size((x, y))
    return item.resize(target_size)

def threshold_item(item):
    """Binarise a mask image: 0 stays 0, every value of 1 or more becomes 255.

    The image is converted to an array, clipped to the range [0, 1], cast to
    uint8 and multiplied by 255; the result is returned as a `PILImage`.
    """
    binary = array(item).clip(0, 1).astype("uint8")
    return PILImage.create(binary * 255)

def verify_shape_size(resized_pred_dir,org_dir):
    """Check that every resized mask has the same shape as its original image.

    A mask in `resized_pred_dir` belongs to the file in `org_dir` whose stem
    equals the mask's stem without its first five characters (the 'pred_'
    prefix). Every mismatch is reported with both shapes; if none is found,
    'All shapes match!' is printed. Nothing is returned.
    """
    print('Verifying shape sizes...')
    all_match = True

    # List the originals once instead of once per mask.
    originals = list(org_dir.iterdir())

    for cnt, resized_msk in enumerate(resized_pred_dir.iterdir()):
        mask_name = resized_msk.stem[5:]  # drop the 'pred_' prefix
        for original_img in originals:
            if original_img.stem != mask_name:
                continue

            # PIL reports (width, height); flip it to (rows, columns) so it is
            # comparable with the rasterio dataset shape.
            with Image.open(resized_msk) as resized_pred_msk_img:
                width, height = resized_pred_msk_img.size
            mask_shape = (height, width)

            with rio.open(original_img) as dataset:
                original_shape = dataset.shape

            if original_shape != mask_shape:
                print(cnt, 'ATTENTION NOT SAME SHAPE!')
                print('Resized predicted Mask: ', resized_msk.name, '  Shape: ', mask_shape)
                # Report the original's own shape, not the mask's shape again.
                print('Original Image: ', original_img.name, '  Shape: ', original_shape, '\n')
                all_match = False

    if all_match:
        print('All shapes match!')


def align_pred_to_original_shape(pred_dir, org_dir, resized_pred_dir_out):
    """Resize every predicted mask to the shape of its original image.

    For each mask in `pred_dir` the matching original in `org_dir` is looked
    up by name (the mask's stem without the five-character 'pred_' prefix).
    The mask is resized to that image's shape, binarised with
    `threshold_item` and saved under the same file name in
    `resized_pred_dir_out`. Masks without a matching original are reported
    and skipped. Finally the resulting shapes are verified.
    """
    # Check directory: create the output folder that is actually written to.
    resized_pred_dir_out.mkdir(parents=True, exist_ok=True)

    print('Resizing predicted mask..')

    # Align pred mask shape with original image shape and save mask into new folder
    for pred_msk in pred_dir.iterdir():
        org_img_shape = get_original_shape(pred_msk.stem[5:], org_dir)
        if org_img_shape is None:
            # Without an original there is no target size to resize to.
            print('No original image found for', pred_msk.name, '- skipped')
            continue

        # The shape is (rows, columns) while resize_items expects x (width) first.
        with Image.open(pred_msk) as pred_msk_img:
            pred_msk_img_resized = resize_items(item=pred_msk_img, x=org_img_shape[1], y=org_img_shape[0])

        pred_msk_img_resized_threshold = threshold_item(item=pred_msk_img_resized)
        pred_msk_img_resized_threshold.save(resized_pred_dir_out / pred_msk.name)

    verify_shape_size(resized_pred_dir_out, org_dir)

align_pred_to_original_shape(PRED_MASK_PATH,IMAGES_PATH,PRED_MASK_PATH_TFMS)

In [None]:
# 6. Map and extract coordinates into GeoJSON

def get_pixel_coordinates(pred_mask, min_area=210):
    """Return the pixel centroids of all sufficiently large regions in a mask.

    Args:
        pred_mask: path of the mask image file.
        min_area: contours whose area is not larger than this value are
            treated as small marks and ignored.

    Returns:
        A list of [cX, cY] integer pairs, one per accepted contour, where cX
        is the centroid's x (column) and cY its y (row) coordinate.

    Raises:
        OSError: if the mask file cannot be read as an image.
    """
    image = cv2.imread(str(pred_mask))
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"Could not read mask image: {pred_mask}")

    # Convert Mask into 2dimensional array
    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Find contours in the binary image
    contours, _ = cv2.findContours(image=gray_image, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_SIMPLE)

    crossings = []
    for contour in contours:
        if cv2.contourArea(contour) <= min_area:  # Exclude small marks
            continue

        # Centroid from the contour's moments
        moments = cv2.moments(contour)
        if moments["m00"] == 0:
            # Degenerate contour: no centroid can be computed.
            continue
        cX = int(moments["m10"] / moments["m00"])
        cY = int(moments["m01"] / moments["m00"])
        crossings.append([cX, cY])

    return crossings


def get_real_coordinates(data):
    """Convert the first record of `data` with `convert_coordinates`.

    Only the first two entries of `data[0]` are used; the converted pair is
    returned as a tuple.
    """
    source_x, source_y = data[0][0], data[0][1]
    first, second = convert_coordinates(source_x, source_y)
    return first, second


# Transformers are kept per (source, target) pair so that one is not rebuilt
# for every single point.
_TRANSFORMER_CACHE = {}


def convert_coordinates(x, y, src_crs=3857, dst_crs=4326):
    """Transform one point from `src_crs` to `dst_crs`.

    Defaults:
        EPSG:3857 --> Spherical Mercator projection coordinate
        EPSG:4326 --> WGS84

    Pass other EPSG codes if the source images are tagged with a different
    projection. The transformed point is returned as a tuple in the axis
    order pyproj uses for `dst_crs`.
    """
    key = (src_crs, dst_crs)
    transformer = _TRANSFORMER_CACHE.get(key)
    if transformer is None:
        transformer = Transformer.from_crs(src_crs, dst_crs)
        _TRANSFORMER_CACHE[key] = transformer
    return transformer.transform(x, y)
    
    
def get_geo_json(coordinates):
    """Build a GeoJSON FeatureCollection from a list of coordinate pairs.

    Every pair (x, y) in `coordinates` becomes one Feature whose Point
    geometry stores the values in swapped order, i.e. (y, x).
    """
    features = [Feature(geometry=Point((second, first))) for first, second in coordinates]
    return FeatureCollection(features)
                      
    

def main():
    """Map the centroids of all predicted islands to coordinates and save them.

    For every mask in PRED_MASK_PATH_TFMS the original image with the same
    name (mask stem without the five-character 'pred_' prefix) is opened from
    IMAGES_PATH. Each centroid found in the mask is converted from pixel
    position to dataset coordinates and then passed through
    `get_real_coordinates`, unless band 4 of the original holds the dataset's
    nodata value at that pixel. All results are written to
    'island_coordinates.geojson' and returned as a FeatureCollection.
    """
    real_coordinates = []
    print('Mappingprocess started..')

    # List both folders once instead of once per iteration.
    masks = PRED_MASK_PATH_TFMS.ls()
    originals = IMAGES_PATH.ls()
    total = len(masks)

    # Get mask and real image
    for cnt, resized_msk in enumerate(masks, start=1):
        mask_name = resized_msk.stem[5:]  # drop the 'pred_' prefix
        for original_img in originals:
            if original_img.stem != mask_name:
                continue
            with rio.open(original_img) as dataset:

                # Get pixel coordinates
                crossings = get_pixel_coordinates(resized_msk)
                if not crossings:
                    continue

                # Read the band once per image, not once per crossing.
                val = dataset.read(4)
                no_data = dataset.nodata

                for cX, cY in crossings:
                    # The band array is indexed [row, column], i.e. [cY, cX].
                    if val[cY, cX] != no_data:
                        geo_x, geo_y = dataset.xy(cY, cX)
                        data = [(geo_x, geo_y, val[cY, cX])]
                        real_coordinates.append(get_real_coordinates(data))
        print('item ' + str(cnt) + '/' + str(total))

    geo_json = get_geo_json(real_coordinates)

    with open('island_coordinates.geojson','w') as f:
        dump(geo_json,f)

    return geo_json


main()