In [28]:
# WRITTEN BY ALAN LI
# NASA AMES LABORATORY FOR ADVANCED SENSING (LAS)
# Last edited: Oct 29, 2018

import sys
sys.path.append("./utils/") # Adds higher directory to python modules path.
import random
import numpy as np
import cv2
import glob, os
from collections import Counter
from skimage.measure import compare_ssim as ssim

from matplotlib import pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import colors
from collections import Counter
import pandas as pd
import logging
import yaml
import json
from PIL import Image as pil_image
import scipy
from osgeo import gdal, ogr, osr

import tensorflow as tf
from tensorflow.python.client import device_lib
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import loadcoraldata_utils as coralutils
import keras
import keras.backend as K
from keras.models import Model, Sequential, load_model
from keras.utils import np_utils
from keras import optimizers
from keras.preprocessing.image import img_to_array, load_img
from keras.callbacks import Callback
from keras.losses import mean_squared_error
from keras.utils.np_utils import to_categorical
from keras.models import Model

from NeMO_models import FCN
from NeMO_losses import charbonnierLoss
import NeMO_layers
import NeMO_encoders
from NeMO_generator import NeMOImageGenerator, ImageSetLoader
from NeMO_callbacks import CheckNumericsOps, WeightsSaver

num_cores = 4
num_GPU = 1
num_CPU = 1

global _SESSION
config = tf.ConfigProto(intra_op_parallelism_threads=num_cores,
        inter_op_parallelism_threads=num_cores, allow_soft_placement=True,
        device_count = {'CPU' : num_CPU, 'GPU' : num_GPU})
_SESSION = tf.Session(config=config)
K.set_session(_SESSION)

%matplotlib notebook

model = load_model('./tmp/RefineMask_Cloud.h5', custom_objects={'BilinearUpSampling2D':NeMO_layers.BilinearUpSampling2D, 
                                                    'charbonnierLoss': charbonnierLoss})

In [29]:
def loadrasterRGB(filepath,channels):
    """Stack selected bands of a GDAL raster into one float array.

    Args:
        filepath: path of a raster that ``gdal.Open`` can read.
        channels: sequence of GDAL band numbers, passed straight to
            ``GetRasterBand`` (GDAL band numbers start at 1). The output keeps
            this order.

    Returns:
        float64 array of shape (RasterYSize, RasterXSize, len(channels)).
    """
    dataset = gdal.Open(filepath)
    n_rows, n_cols = dataset.RasterYSize, dataset.RasterXSize

    stacked = np.zeros((n_rows, n_cols, len(channels)))
    for out_idx, band_number in enumerate(channels):
        stacked[:, :, out_idx] = dataset.GetRasterBand(band_number).ReadAsArray()
    return stacked

def fixRGBimg(img,pixel_mean,pixel_std,minval,maxval):
    """Undo ``(x - pixel_mean) / pixel_std`` normalisation and return a uint8 image.

    Args:
        img: normalised image (array-like). It is not modified.
        pixel_mean, pixel_std: scalars or arrays broadcastable against ``img``
            (e.g. one value per channel).
        minval, maxval: clamp range applied before the uint8 cast. Both should
            lie within [0, 255]; values outside that range wrap in the cast.

    Returns:
        uint8 array with the same shape as ``img``.
    """
    restored = np.asarray(img) * pixel_std + pixel_mean
    # Round, don't truncate: the float round trip can land just below an
    # integer, e.g. (44 - 100) / 100 * 100 + 100 == 43.99999999999999, which
    # a plain uint8 cast would turn into 43.
    restored = np.clip(np.rint(restored), minval, maxval)
    return restored.astype(np.uint8)

def raster_lores(filepath, downscale):
    """Load every band of a GDAL raster and resample it by a scale factor.

    This reproduces the removed ``scipy.misc.imresize(band, downscale,
    mode='F')``. That call wrapped each band as a 32-bit float PIL image and
    resized it with bilinear interpolation.

    Args:
        filepath: path of a raster that ``gdal.Open`` can read.
        downscale: factor applied to both dimensions (e.g. 0.5 halves them).
            Output sizes are ``int(size * downscale)``, i.e. truncated.

    Returns:
        float64 array of shape
        (int(RasterYSize*downscale), int(RasterXSize*downscale), RasterCount).

    Raises:
        ValueError: if the scaled size is zero in either dimension.
    """
    img = gdal.Open(filepath)
    xsize = img.RasterXSize
    ysize = img.RasterYSize
    channels = img.RasterCount

    downscalesize_x = int(xsize*downscale)
    downscalesize_y = int(ysize*downscale)
    if downscalesize_x < 1 or downscalesize_y < 1:
        raise ValueError("downscale=%r shrinks a %dx%d raster to zero size"
                         % (downscale, xsize, ysize))

    lores_image = np.zeros((downscalesize_y, downscalesize_x, channels))

    for band in range(channels):
        imgband = img.GetRasterBand(band+1).ReadAsArray()  # GDAL bands are 1-based
        # scipy.misc.imresize was removed in SciPy 1.3; do what it did for
        # mode='F' directly. A 2-D float32 array becomes a mode 'F' PIL image,
        # and PIL's resize takes (width, height).
        band_img = pil_image.fromarray(np.asarray(imgband, dtype=np.float32))
        resized = band_img.resize((downscalesize_x, downscalesize_y),
                                  resample=pil_image.BILINEAR)
        lores_image[:,:,band] = np.asarray(resized)

    return lores_image

In [70]:
# 4-channel Sentinel cloud patch and its reference cloud mask.
filepath = "/home/shared/NeMO-Net Data/cloudImages4training/CloudMask_Valid_Patches_4channel_256/Clouds/Clouds_00000007.tif"
filepath_cloudmask = "/home/shared/NeMO-Net Data/cloudImages4training/CloudMask_ValidRef_Patches_4channel_256/Clouds/Clouds_00000007.tif"
Fiji_Sentinel = coralutils.CoralData(filepath, Truthpath=filepath_cloudmask, load_type="raster")

# Quick-look RGB from the first three bands. Scale by 255/200 (raw 200 -> 255)
# and clamp to [0, 255] before the uint8 cast. The lower clamp matters because
# casting a negative float to uint8 does not saturate at 0.
patch_R, patch_G, patch_B = (np.clip(255/200*Fiji_Sentinel.image[:,:,band], 0, 255)
                             for band in range(3))
Fiji_RGB = np.dstack((patch_R, patch_G, patch_B)).astype(np.uint8)

fig = plt.figure(1, figsize=(12,5))
ax = fig.add_subplot(1,2,1)
ax.imshow(Fiji_RGB)
ax.set_title("Quick-look RGB (bands 0-2)")
ax = fig.add_subplot(1,2,2)
ax.imshow(Fiji_Sentinel.truthimage)
ax.set_title("Reference cloud mask")
plt.show()

Truthimage geotransform is not set! Reverting to default image's geotransform...


<IPython.core.display.Javascript object>

In [71]:
def classifyback(predictions):
    """Turn per-class scores into label indices.

    Takes the argmax over the trailing axis, so an (..., num_classes) input
    yields an integer array of shape (...).
    """
    class_axis = -1  # class scores live on the last axis
    return np.argmax(predictions, axis=class_axis)

# Input normalisation applied before prediction: (x - pixel_mean) / pixel_std.
pixel_mean = 100
pixel_std = 100

testimg = Fiji_Sentinel.image
testimg = img_to_array(testimg, data_format=K.image_data_format())
testimg = (testimg-pixel_mean)/pixel_std
imgbatch = np.expand_dims(testimg, axis=0)  # single-image batch
imgpredict = model.predict(imgbatch, batch_size = 1)[0]
# Reshape the argmax labels to the patch's spatial size (the image is indexed
# [row, col, band]) instead of a hard-coded 256x256. This assumes the model
# output has the same spatial size as its input.
imgpredict = np.reshape(classifyback(imgpredict), Fiji_Sentinel.image.shape[:2])

fig = plt.figure(2, figsize=(12,5))
ax1 = fig.add_subplot(1,1,1)
ax1.imshow(imgpredict)
ax1.set_title("Predicted")
plt.show()

print(np.unique(imgpredict))

# Work on a copy. Assigning into Fiji_Sentinel.truthimage directly would
# silently change the mask held by Fiji_Sentinel for every later use.
imgtruth = Fiji_Sentinel.truthimage.copy()
# Map mask value 127 to label 1 so it is comparable with the argmax labels.
# NOTE(review): assumes the mask only holds 0 and 127 - TODO confirm encoding.
imgtruth[imgtruth == 127] = 1
perc_correct = np.mean(imgpredict == imgtruth)  # fraction of matching pixels
print(perc_correct)

<IPython.core.display.Javascript object>

[0 1]
0.9303741455078125


In [79]:
# Class dictionaries used for truth-class consolidation.
jsonpath = './utils/CoralClasses.json'
with open(jsonpath) as json_file:
    json_data = json.load(json_file)

# Fiji mosaic tile and its matching truth raster.
imgpath = '/home/shared/NeMO-Net Data/processed_Fiji_files/mosaiced_001.TIF'
truthpath = '/home/shared/NeMO-Net Data/processed_Fiji_files/truth_classes_wcldMask_001.TIF'
Fiji = coralutils.CoralData(
    imgpath,
    Truthpath=truthpath,
    truth_key=json_data["Fiji_ClassDict"],
    load_type="raster",
)

# Consolidate truth classes in two hops: Fiji -> L3 -> VedConsolidated.
Fiji.Consolidate_classes(newclassdict=json_data["L3_ClassDict"],
                         transferdict=json_data["Fiji2L3_Dict"])
Fiji.Consolidate_classes(newclassdict=json_data["VedConsolidated_ClassDict"],
                         transferdict=json_data["L32VedConsolidated_Dict"])

# Labels predicted by the cloud model, plus the inverse (index -> name) lookup.
labelkey = {'Clouds': 0, 'No clouds': 1}
class_indices_rev = {idx: name for name, idx in labelkey.items()}
print(class_indices_rev)

# Prediction window. NumPy slicing clamps to the image extent if the image is smaller.
xstart, xlen = 0, 4864
ystart, ylen = 0, 4864
predict_size = 256

# Drop bands 0, 3, 5 and 7 (hard-coded, leaving B, G, R + NIR), then reorder
# the first three so the stack becomes R, G, B, NIR.
Fiji.testimage = np.delete(Fiji.image[ystart:ystart+ylen, xstart:xstart+xlen, :], [0, 3, 5, 7], axis=2)
Fiji.testimage = np.dstack((Fiji.testimage[:, :, 2::-1], Fiji.testimage[:, :, 3]))

Fiji.set_mean_std(100, 100)
print("Total Lines: ", Fiji.testimage.shape[0])
whole_predict, num_predict, prob_predict, truth_predict, accuracy = Fiji.predict_on_whole_image(
    model,
    image_size=256,
    num_classes=len(class_indices_rev),
    spacing=(64, 64),
    predict_size=predict_size,
    num_lines=None,
    lastchannelremove=False,
)

  self.consolclass_weights = dict((k, (self.truthimage_consolidated.shape[0]*self.truthimage_consolidated.shape[1])/(self.truthimage_consolidated==newclassdict[k]).sum()) for k in newclassdict)


{0: 'Clouds', 1: 'No clouds'}
Total Lines:  1024
Line: 768 Col: 768/ 769 completed

In [85]:
# Quick-look RGB of the prediction window; its first three bands were reordered
# to R, G, B when testimage was built. Use the same 255/200 scaling as the
# Sentinel patch, and clamp to [0, 255] so out-of-range values saturate
# instead of wrapping in the uint8 cast.
patch_R, patch_G, patch_B = (np.clip(255/200*Fiji.testimage[:,:,band], 0, 255)
                             for band in range(3))
Fiji_RGB = np.dstack((patch_R, patch_G, patch_B)).astype(np.uint8)

fig = plt.figure(3, figsize=(15,8))
ax1 = fig.add_subplot(1,2,1)
ax1.imshow(Fiji_RGB)
ax1.set_title("Quick-look RGB")

ax2 = fig.add_subplot(1,2,2)
ax2.imshow(whole_predict)
ax2.set_title("Predicted (whole image)")
plt.show()


<IPython.core.display.Javascript object>