   Author: Ankit Kariryaa, University of Bremen

   Modified by Jiawei Wei

In [151]:
from tensorflow.keras.models import load_model
import os
import geopandas as gps
import rasterio                  # I/O raster data (netcdf, height, geotiff, ...)
import rasterio.warp             # Reproject raster samples
from rasterio import windows
# import fiona                     # I/O vector data (shape, geojson, ...)
# import geopandas as gps

from shapely.geometry import Point, Polygon
from shapely.geometry import mapping, shape

import numpy as np               # numerical array manipulation
import os
from tqdm import tqdm
import PIL.Image
import PIL.ImageDraw

from itertools import product
# from tensorflow.keras.models import load_model


import sys
from core.UNet import UNet
from core.losses_FTL import focalTversky, accuracy, dice_coef, dice_loss, specificity, sensitivity, PA, IoU_Pos, IoU_Neg, mIoU, F1_Score
from core.optimizers import adaDelta, adagrad, adam, nadam
from core.frame_info import FrameInfo, image_normalize
from core.dataset_generator import DataGenerator
from core.split_frames import split_dataset
from core.visualize import display_images

%matplotlib inline
import matplotlib.pyplot as plt  # plotting tools
import matplotlib.patches as patches
# from matplotlib.patches import Polygon

import warnings                  # ignore annoying warnings
warnings.filterwarnings("ignore")
import logging
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

%reload_ext autoreload
%autoreload 2
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'

import tensorflow as tf
print(tf.__version__)

2.5.0-rc3


In [152]:
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# TF1-style session configuration: soft placement on, device-placement logging off,
# and GPU memory allocated on demand (allow_growth) instead of all at once.
# Bound to `tf_config` rather than `config` so that (re-)running this cell can never
# overwrite the project configuration object `config` used by the later cells.
tf_config = ConfigProto(
    # device_count={"CPU": 64},
    allow_soft_placement=True,
    log_device_placement=False)
tf_config.gpu_options.allow_growth = True
# To cap GPU memory instead: tf_config.gpu_options.per_process_gpu_memory_fraction = 0.7
session = InteractiveSession(config=tf_config)

In [153]:
# Required configurations (including the input and output paths) are stored in a separate module,
# here the `RasterAnalysis_withLocation` module of the `config` package (imported below).
# Please provide the required values in that module before continuing with this notebook.
 
from config import RasterAnalysis_withLocation
# In case you are using a different folder name such as configLargeCluster, import from that folder instead,
# e.g. from configLargeCluster import RasterAnalysis_withLocation

# NOTE: from here on the global name `config` refers to this Configuration instance
# (it is read later as config.BATCH_SIZE, config.input_image_dir, config.output_dir, ...).
config = RasterAnalysis_withLocation.Configuration()

In [154]:
# Load a pretrained model and re-compile it with the optimizer, loss and metrics below.
# The loss/metric functions are project-defined, so Keras needs them in `custom_objects`
# to deserialize the saved model.
OPTIMIZER = adam
LOSS = focalTversky
METRICS = [dice_coef, dice_loss, accuracy, specificity, sensitivity, PA, IoU_Pos, IoU_Neg, mIoU, F1_Score]
import os
custom_objects = {
    'focalTversky': LOSS,
    'dice_coef': dice_coef,
    'dice_loss': dice_loss,
    'accuracy': accuracy,
    'specificity': specificity,
    'sensitivity': sensitivity,
    'PA': PA,
    'IoU_Pos': IoU_Pos,
    'IoU_Neg': IoU_Neg,
    'mIoU': mIoU,
    'F1_Score': F1_Score,
}
model = load_model(config.trained_model_path, custom_objects=custom_objects, compile=False)
# Use LOSS (not a hard-coded focalTversky) so that changing LOSS above actually takes effect.
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)

In [5]:
import pandas as pd

# Station metadata table. The prediction loop looks up each station's `Lon`/`Lat`
# by matching the station name parsed from the image file name against `Name`.
location_df = pd.read_excel(io=r'I:\results\SST\landsat\location.xlsx')
location_df

Unnamed: 0,Name,Lat,Lon,Location,Radius,Drainage,Country,Region,Capacity,Start_date,End_date,CP1,CP2,CP3,CP4,CP5,CP6,CP7,CP8
0,DayaBay-Lingao,22.6076,114.5641,Bay,225,Shallow,China,East Asia,5802,2002-02-26,NaT,2002-12-15,2010-07-15,2011-05-03,NaT,NaT,NaT,NaT,NaT
1,Yangjiang,21.7024,112.2713,Open,75,Deep,China,East Asia,6000,2013-12-31,NaT,2015-03-10,2015-10-18,2017-01-08,2018-05-23,2019-06-29,NaT,NaT,NaT
2,Changjiang,19.4737,108.8834,Open,175,Deep,China,East Asia,1202,2015-11-07,NaT,2016-06-20,NaT,NaT,NaT,NaT,NaT,NaT,NaT
3,Fuqing,25.4152,119.4365,Bay,125,Shallow,China,East Asia,5000,2014-08-20,NaT,2015-08-06,2016-09-07,2017-07-29,2020-11-27,NaT,NaT,NaT,NaT
4,Ningde,27.0524,120.2868,Bay,275,Shallow,China,East Asia,4072,2012-12-28,NaT,2014-01-04,2015-03-21,2016-03-29,NaT,NaT,NaT,NaT,NaT
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
69,Point_Beach,44.2822,-87.5327,Lake,175,Shallow,United States,North America,1182,1970-11-06,NaT,1972-08-02,NaT,NaT,NaT,NaT,NaT,NaT,NaT
70,Robert_E_Ginna,43.2805,-77.3082,Lake,125,Shallow,United States,North America,560,1969-12-02,NaT,NaT,NaT,NaT,NaT,NaT,NaT,NaT,NaT
71,Zion,42.4450,-87.7982,Lake,125,Deep,United States,North America,2080,1973-06-28,1998-02-13,1973-12-26,1998-02-13,NaT,NaT,NaT,NaT,NaT,NaT
72,Bruce_1,44.3448,-81.5784,Lake,275,Shallow,Canada,North America,6358,1977-01-14,NaT,1976-09-04,1977-12-12,1978-12-21,1984-12-02,1984-06-26,1986-02-22,1987-03-09,NaT


In [155]:
# Methods to add results of a patch to the total results of a larger area. The operator could be min (useful if there are too many false positives), max (useful for tackle false negatives)
def addTOResult(res, prediction, row, col, he, wi, operator = 'MAX'):
    """Merge the prediction of one patch into the result array of a larger area.

    Args:
        res: 2-D result array for the whole area; updated in place.
        prediction: 2-D prediction of one patch; only its top-left `he` x `wi`
            part is used (edge patches are only partially inside the image).
        row, col: top-left position of the patch inside `res`.
        he, wi: height and width of the valid part of the patch.
        operator: how overlapping predictions are combined:
            'MAX'     - keep the larger value (useful against false negatives),
            'MIN'     - keep the smaller value (useful against many false positives),
            'REPLACE' - overwrite with the new prediction.

    Returns:
        `res` (the same object, modified in place).

    Raises:
        ValueError: if `operator` is not one of the values above.
    """
    currValue = res[row:row+he, col:col+wi]
    newPredictions = prediction[:he, :wi]
    if operator == 'MIN':
        # IMPORTANT: MIN only works if `res` was initialised with -1 rather than 0.
        # -1 marks "no prediction yet" and is replaced by 1 so it never wins the minimum.
        currValue[currValue == -1] = 1
        resultant = np.minimum(currValue, newPredictions)
    elif operator == 'MAX':
        resultant = np.maximum(currValue, newPredictions)
    elif operator == 'REPLACE':
        resultant = newPredictions
    else:
        raise ValueError(f"Unknown merge operator {operator!r}; expected 'MIN', 'MAX' or 'REPLACE'")
    # Possible refinement: predictions are likely better in the centre of a patch than on
    # its edges, so a position-weighted merge could beat a plain MIN/MAX for overlapping strides.
    res[row:row+he, col:col+wi] = resultant
    return res

In [156]:
# Methods that actually makes the predictions
def predict_using_model(model, batch, batch_pos, mask, operator):
    """Predict a batch of patches and merge every prediction into `mask`.

    Args:
        model: object with a `predict` method taking the stacked batch.
        batch: list of equally-shaped patches.
        batch_pos: one (col, row, width, height) tuple per patch, giving where
            the valid part of that patch sits inside `mask`.
        mask: full-area result array passed on to addTOResult.
        operator: merge operator passed on to addTOResult ('MIN', 'MAX', 'REPLACE').

    Returns:
        The updated `mask`.
    """
    stacked_batch = np.stack(batch, axis=0)
    predictions = model.predict(stacked_batch)
    for k, (col, row, wi, he) in enumerate(batch_pos):
        # Drop the trailing singleton channel axis of the k-th prediction.
        single_prediction = np.squeeze(predictions[k], axis=-1)
        mask = addTOResult(mask, single_prediction, row, col, he, wi, operator)
    return mask
    

def detect_plume(WST_img, width=256, height=256, stride = 128, normalize=False):
    """Predict a per-pixel plume score for a whole raster using a sliding window.

    The raster is cut into `width` x `height` tiles whose origins are `stride`
    pixels apart; edge tiles are clipped to the image and zero-padded. Tiles are
    collected into batches of `config.BATCH_SIZE` and predicted with the global
    `model`; overlapping predictions are merged with the 'MAX' operator.

    Args:
        WST_img: open rasterio dataset to analyse.
        width, height: tile size in pixels (presumably the model's input size).
        stride: distance between tile origins in pixels.
        normalize: if True, normalise every tile per channel with image_normalize.

    Returns:
        (mask, meta): a (rows, cols) array of merged predictions and a copy of the
        dataset metadata whose dtype is set to float32 unless it was already float.
    """
    ncols, nrows = WST_img.meta['width'], WST_img.meta['height']
    meta = WST_img.meta.copy()
    # The prediction is a float, so keep the output dtype float to be consistent with it.
    if 'float' not in str(meta['dtype']):
        meta['dtype'] = np.float32
    col_starts = range(0, ncols, stride)
    row_starts = range(0, nrows, stride)
    offsets = product(col_starts, row_starts)
    big_window = windows.Window(col_off=0, row_off=0, width=ncols, height=nrows)
    print('the size of current cWST_img', nrows, ncols)

    # Initialised with 0, so only 'MAX'/'REPLACE' merging is valid (see addTOResult for MIN).
    mask = np.zeros((nrows, ncols), dtype=meta['dtype'])
    batch = []
    batch_pos = []
    for col_off, row_off in tqdm(offsets, total=len(col_starts) * len(row_starts)):
        # Clip the tile to the image so tiles on the right/bottom edge may be smaller.
        window = windows.Window(col_off=col_off, row_off=row_off, width=width, height=height).intersection(big_window)
        WST_sm = np.nan_to_num(WST_img.read(window=window), nan=-255)
        temp_im = np.stack(WST_sm, axis=-1)  # (bands, rows, cols) -> (rows, cols, bands)
        if normalize:
            # Normalise along width and height, i.e. independently per channel.
            temp_im = image_normalize(temp_im, axis=(0, 1))
        # Zero padding for edge tiles. NOTE(review): the channel count is hard-coded to 2;
        # confirm it matches the number of bands read and the model's input.
        patch = np.zeros((height, width, 2))
        patch[:window.height, :window.width] = temp_im
        batch.append(patch)
        batch_pos.append((window.col_off, window.row_off, window.width, window.height))
        if len(batch) == config.BATCH_SIZE:
            mask = predict_using_model(model, batch, batch_pos, mask, 'MAX')
            batch = []
            batch_pos = []

    # The number of tiles is rarely a multiple of the batch size: flush the remainder.
    if batch:
        mask = predict_using_model(model, batch, batch_pos, mask, 'MAX')

    return mask, meta

In [157]:
import cv2
from collections import defaultdict
import rasterio.features
# Vector output schema (appears to follow the Fiona schema format, e.g. 'float:15.2'):
# polygon geometry with a string `id` and a float `plume` attribute.
# Not referenced by the functions in the visible part of this notebook.
schema = dict(
    geometry='Polygon',
    properties=dict(id='str', plume='float:15.2'),
)

# Generate a mask with polygons
def transformContoursToXY(contours, transform): # = None
    """Convert OpenCV contours from pixel (col, row) to map (x, y) coordinates.

    Args:
        contours: contours as returned by cv2.findContours; each is indexed as
            cnt[:, 0, :] to obtain its (col, row) points.
        transform: affine transform handed to rasterio.transform.xy.

    Returns:
        A list with one entry per contour, each a list of (x, y) points.
    """
    tp = []
    for cnt in contours:
        pl = cnt[:, 0, :]
        cols, rows = zip(*pl)
        x, y = rasterio.transform.xy(transform, rows, cols)
        if np.isscalar(x):
            # A single-point contour can come back as bare scalars (np.float64 or a
            # plain float depending on the rasterio version) instead of sequences.
            tl = [(x, y)]
        else:
            tl = [list(i) for i in zip(x, y)]
        tp.append(tl)
    return tp

def mask_to_polygons(j, lon, lat, maskF, transform):
    """Vectorise a score mask and return its largest polygon.

    Pixels >= 0.5 are foreground. Contours are found with OpenCV (much faster than
    shapely), converted to map coordinates, and assembled into polygons with holes.
    Only the polygon with the largest area is returned.

    Args:
        j, lon, lat: unused (lon/lat belonged to a disabled "closest polygon to the
            station" selection); kept so existing callers keep working.
        maskF: 2-D array of per-pixel scores; not modified.
        transform: affine transform of the raster grid.

    Returns:
        A one-element list holding the largest polygon, or `[Polygon()]` (empty)
        if nothing usable was detected.
    """
    th = 0.5
    # NaN-safe binarisation: NaN compares False and therefore becomes background.
    mask = np.where(maskF >= th, 255, 0).astype(np.uint8)

    contours, hierarchy = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
    # Convert contours from image coordinates to xy (map) coordinates.
    contours = transformContoursToXY(contours, transform)
    if not contours:  # TODO: raise an error maybe
        print('Warning: No contours/polygons detected!!')
        print(maskF.max())
        return [Polygon()]

    # Associate parent and child contours: every entry of hierarchy[0] is
    # (next, previous, first_child, parent); a contour with a parent is a hole.
    # http://docs.opencv.org/3.1.0/d9/d8b/tutorial_py_contours_hierarchy.html
    cnt_children = defaultdict(list)
    child_contours = set()
    assert hierarchy.shape[0] == 1
    for idx, (_, _, _, parent_idx) in enumerate(hierarchy[0]):
        if parent_idx != -1:
            child_contours.add(idx)
            cnt_children[parent_idx].append(contours[idx])

    all_polygons = []
    for idx, cnt in enumerate(contours):
        if idx in child_contours or len(cnt) < 3:
            continue
        # Holes with fewer than 3 vertices cannot form a ring; drop them rather than
        # letting them make Polygon() fail and discard the whole outer shape.
        holes = [c for c in cnt_children.get(idx, []) if len(c) >= 3]
        try:
            all_polygons.append(Polygon(shell=cnt, holes=holes))
        except Exception as e:  # best effort: report and skip an invalid contour
            print(e)

    if not all_polygons:
        return [Polygon()]
    # Keep only the largest polygon (first one on ties).
    return [max(all_polygons, key=lambda poly: poly.area)]


def create_contours_shapefile(j, lon, lat, mask, meta, out_fn):
    """Burn the largest polygon detected in `mask` back into a raster of the same grid.

    Despite the name, no shapefile is written; `out_fn` is currently unused.

    Args:
        j, lon, lat: forwarded to mask_to_polygons.
        mask: 2-D score/binary mask; values >= 0.5 count as detections.
        meta: raster metadata; only meta['transform'] is used.
        out_fn: unused; kept for backward compatibility.

    Returns:
        uint8 array shaped like `mask`: 1 for pixels touched by the polygon, else 0
        (all zeros when no polygon was found).
    """
    res = mask_to_polygons(j, lon, lat, mask, meta['transform'])
    if res[0].is_empty:
        # Same dtype as the rasterize branch so callers always get uint8.
        return np.zeros(mask.shape, dtype=np.uint8)
    return rasterio.features.rasterize(
        res,
        fill=0,
        out_shape=mask.shape,
        transform=meta['transform'],
        all_touched=True,
        default_value=1,
        dtype='uint8',
    )


# Scenes (station_name, date string parsed from the file name) whose detections
# are blanked out entirely.
_EXCLUDED_SCENES = {('Dungeness', '19890630'), ('Vandellos', '20100314')}


def writeMaskToDisk(j, lon, lat, WST, label, station_name, time_str, detected_mask, detected_meta, wp, write_as_type = 'uint8', th = 0.5, create_countors = True): #uint8
    """Mask band 1 of `WST` with the detected plume and write it to `wp` as float32.

    Steps:
      1. If the prediction is float and `write_as_type` is an integer type,
         binarise it at `th` and cast it to `write_as_type`.
      2. Zero detections where the label raster is 0, and apply the
         station-specific exclusions.
      3. Keep only the largest detected polygon (create_contours_shapefile).
      4. Write band 1 of `WST` with every pixel outside that polygon, outside the
         detection, or equal to 0 set to NaN.

    Args:
        j, lon, lat: forwarded to create_contours_shapefile.
        WST: open rasterio dataset; band 1 is the raster that gets written.
        label: open rasterio dataset, presumably on the same grid as WST
            (TODO confirm); band-1 pixels equal to 0 are excluded.
        station_name, time_str: used for the hard-coded exclusions.
        detected_mask, detected_meta: output of detect_plume. `detected_meta` is
            modified in place (dtype/count) and used as the output profile.
        wp: output file path.
        write_as_type: dtype of the binarised prediction.
        th: binarisation threshold.
        create_countors: unused; kept for backward compatibility.
    """
    if 'float' in str(detected_meta['dtype']) and 'int' in write_as_type:
        print(f'Converting prediction from {detected_meta["dtype"]} to {write_as_type}, using threshold of {th}')
        # Binarise without mutating the caller's array; NaN becomes 0.
        detected_mask = (detected_mask >= th).astype(write_as_type)
        print(detected_mask.dtype)

    # Mask prediction results outside of the label. This and the steps below run for
    # every dtype combination (previously WST_arr was undefined unless step 1 ran).
    detected_mask[label.read(1) == 0] = 0

    WST_arr = WST.read(1)
    if not np.issubdtype(WST_arr.dtype, np.floating):
        # Integer arrays cannot hold NaN; promote before masking.
        WST_arr = WST_arr.astype(np.float32)
    WST_arr[detected_mask == 0] = np.nan
    WST_arr[WST_arr == 0] = np.nan

    if station_name == 'Ohi':
        # Ohi: pixels with a value above 10 are never counted as plume.
        detected_mask[WST_arr > 10] = 0
    elif (station_name, time_str) in _EXCLUDED_SCENES:
        detected_mask[:] = 0

    location_mask = create_contours_shapefile(j, lon, lat, detected_mask, detected_meta, wp)
    # Keep only pixels inside the retained polygon; with no polygon everything becomes NaN.
    WST_arr[location_mask == 0] = np.nan

    if station_name == 'Ohi':
        WST_arr[WST_arr > 10] = np.nan

    detected_meta['dtype'] = np.float32
    detected_meta['count'] = 1
    with rasterio.open(wp, 'w', **detected_meta) as outds:
        outds.write(WST_arr.astype(np.float32, copy=False), 1)


In [None]:
# Predict plumes in the all the files in the input image dir
# Depending upon the available RAM, images may not to be split before running this cell.
# Use the Auxiliary-2-SplitRasterToAnalyse if the images are too big to be analysed in memory.
# Predict plumes in all the files in the input image dir.
# Depending upon the available RAM, images may need to be split before running this cell.
# Use the Auxiliary-2-SplitRasterToAnalyse if the images are too big to be analysed in memory.
all_files = []
all_labels = []
for root, dirs, files in os.walk(config.input_image_dir):
    files.sort()
    for file in files:
        if not file.endswith(config.input_image_type):
            continue
        if file.startswith(config.WST_fn_st):
            all_files.append((os.path.join(root, file), file))
        if file.startswith(config.label_fn_st):
            all_labels.append((os.path.join(root, file), file))

# Images and labels are paired purely by position, so the two lists must line up.
if len(all_files) != len(all_labels):
    raise ValueError(f'Found {len(all_files)} images but {len(all_labels)} label rasters in {config.input_image_dir}')
pd.DataFrame({'image': [fn for _, fn in all_files], 'label': [fn for _, fn in all_labels]})

for idx, ((fullPath, filename), (lb_path, _)) in enumerate(zip(all_files, all_labels)):
    outputFile = os.path.join(config.output_dir, filename.replace(config.WST_fn_st, config.output_prefix))
    if os.path.isfile(outputFile) and not config.overwrite_analysed_files:
        print('File already analysed!', fullPath)
        continue

    # The station name sits between the first '_' after position 6 and the last 'L'
    # of the file name; the 8 characters before the extension are the date string.
    idImg_sub1 = filename.find('_', 6)
    idImg_sub2 = filename.rfind('L')
    station_name = filename[idImg_sub1+1:idImg_sub2-1]
    time_str = filename[-12:-4]
    print(filename)
    station_rows = location_df.loc[location_df['Name'] == station_name]
    if station_rows.empty:
        raise KeyError(f'Station {station_name!r} (parsed from {filename}) not found in location_df')
    lon = station_rows['Lon'].values[0]
    lat = station_rows['Lat'].values[0]

    with rasterio.open(fullPath) as WST, rasterio.open(lb_path) as label:
        # WIDTH and HEIGHT should be the same; here the stride is 50 % of the width.
        detectedMask, detectedMeta = detect_plume(WST, width=config.WIDTH, height=config.HEIGHT, stride=config.STRIDE, normalize=False)
        # Write the masked raster to file.
        writeMaskToDisk(idx, lon, lat, WST, label, station_name, time_str, detectedMask, detectedMeta, outputFile, write_as_type=config.output_dtype, th=0.5, create_countors=False)

In [None]:
# Display one written output raster.
# Set `sampleImage` to the part of an output file name that follows config.output_prefix.
sampleImage = ''
fn = os.path.join(config.output_dir, config.output_prefix + sampleImage)
# Read everything inside a context manager so the dataset handle is closed again.
with rasterio.open(fn) as predicted_img:
    p = predicted_img.read()
np.unique(p, return_counts=True)
plt.imshow(p[0])