In [None]:
from __future__ import division, print_function
import sys
import os
import h5py
import matplotlib.pyplot as plt
import numpy as np
from osgeo import gdal
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dropout,ZeroPadding2D, Dropout,Concatenate,Conv2DTranspose,UpSampling2D
from tensorflow.keras.layers import Activation, Reshape
from tensorflow.keras.layers import Convolution2D, Conv2D,MaxPooling2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from numpy.random import seed
import time
# Seed NumPy's global random generator (`seed` is numpy.random.seed here).
seed(1234)
# Uncomment to inspect the devices TensorFlow can see:
#tf.test.is_gpu_available()
#tf.config.list_physical_devices('CPU')

# TF1-style seeding, kept for reference:
#from tensorflow import set_random_seed
#set_random_seed(1234)
# Seed TensorFlow's global random generator (TF2 API).
tf.random.set_seed(1234)

## data utility functions
def to_categorical_4d(y, nc):
    """Convert a reclassed ground truth array to one-hot encoding.

    Keyword arguments:
    y -- ground truth array indexed as (sample, channel, row, col); only
         channel 0 is read. Values are cast to int32 before use.
    nc -- number of classes

    Returns an int32 array of shape (samples, nc, rows, cols). A pixel with
    label k (k != 0) gets a 1 in channel k-1; pixels labelled 0 stay all-zero.
    """
    labels = np.asarray(y)[:, 0, :, :].astype(np.int32)
    onehot = np.zeros((labels.shape[0], nc, labels.shape[1], labels.shape[2]),
                      dtype=np.int32)

    # Coordinates of every labelled pixel; a single fancy-index assignment
    # replaces the per-pixel Python loop.
    h, i, j = np.nonzero(labels)
    onehot[h, labels[h, i, j] - 1, i, j] = 1

    return onehot

def swap_arr(arr):
    """
    Move axis 1 to position 3, keeping the other axes in order.

    For a 4-D array this turns (n, c, h, w) into (n, h, w, c), i.e. the
    channels-last layout used by the TensorFlow backend.
    """
    return np.moveaxis(arr, 1, 3)


def mix_params(params):
    """Create a list of possible parameter combinations.

    Keyword arguments:
    params -- dict mapping each parameter name to the values to try.

    Returns a list with one dict per combination (Cartesian product of the
    value lists). Keys follow the iteration order of `params`, and the last
    parameter varies fastest. An empty `params` yields `[{}]`; any parameter
    with no values yields `[]`.
    """
    # Imported locally so the function works no matter which notebook cell
    # has been run; the file-level import of itertools lives in a later cell.
    import itertools

    names = list(params)
    value_lists = [params[name] for name in names]
    return [dict(zip(names, combo))
            for combo in itertools.product(*value_lists)]

def make_dir_paths(mypath):
    """Create directory `mypath` (with parents) unless it already exists.

    Prints "created path" only when the directory was actually created.
    """
    if os.path.isdir(mypath):
        return
    os.makedirs(mypath)
    print("created path")

# FCN atrous (original architecture)
def add_common_layers(y):
    """
    Apply a ReLU activation followed by batch normalization to tensor `y`.

    BatchNormalization is created with its default arguments, so it
    normalizes over the layer's default axis.
    """
    activated = Activation('relu')(y)
    return BatchNormalization()(activated)

def fcn_atr_orig(psize, nc, nb, weights_path=None):
    """
    Build the original FCN-atrous architecture.

    psize -- patch size (input is psize x psize)
    nc -- number of classes (channels of the softmax output)
    nb -- number of input channels
    weights_path -- optional path to weights loaded into the model
                    (e.g. for pre-training)

    The network has six stages. Each stage is: zero padding, a dilated
    (rate 2) 'valid' convolution with 64 filters, ReLU + batch normalization,
    zero padding by 1 and a 3x3 max pooling with stride 1. The padding
    matches what the dilated kernel and the pooling remove, so every stage
    keeps the spatial size. The outputs of all six stages are concatenated
    along the channel axis and classified per pixel by a 1x1 convolution
    followed by softmax.
    """
    # (zero padding, kernel size) of the dilated convolution in each stage
    stage_specs = [(4, 5), (4, 5), (2, 3), (2, 3), (2, 3), (2, 3)]

    inp = Input(shape=(psize, psize, nb))

    stage_outputs = []
    features = inp
    for pad, ksize in stage_specs:
        features = ZeroPadding2D((pad, pad))(features)
        features = Conv2D(64, (ksize, ksize), padding='valid',
                          dilation_rate=(2, 2))(features)
        features = add_common_layers(features)
        features = ZeroPadding2D((1, 1))(features)
        features = MaxPooling2D(pool_size=(3, 3), strides=1)(features)
        stage_outputs.append(features)

    # concatenation layer: stack the feature maps of all stages
    merged = Concatenate(axis=3)(stage_outputs)

    scores = Conv2D(nc, (1, 1))(merged)
    probabilities = Activation("softmax")(scores)

    model = Model(inputs=inp, outputs=probabilities)

    if weights_path:
        model.load_weights(weights_path)

    return model

def unet(psize, nc, nb, pretrained_weights = None):
    """
    Build a U-Net style encoder/decoder with a per-pixel softmax output.

    psize -- patch size (input is psize x psize)
    nc -- number of classes (channels of the softmax output)
    nb -- number of input channels
    pretrained_weights -- optional path to weights loaded into the model

    Encoder: four stages (64, 128, 256, 512 filters) of two 3x3 convolutions,
    batch normalization and 2x2 max pooling, then a 1024-filter bottleneck.
    Decoder: four stages that upsample with a stride-2 transposed
    convolution, concatenate the matching encoder output (skip connection)
    and apply two 3x3 convolutions. No dropout layers are used.
    NOTE(review): with four 2x2 poolings, psize presumably has to be a
    multiple of 16 for the skip connections to line up -- confirm.
    """
    def conv_pair(tensor, filters):
        # Two consecutive 3x3 'same' ReLU convolutions.
        for _ in range(2):
            tensor = Conv2D(filters, (3,3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(tensor)
        return tensor

    def decode(tensor, skip, filters):
        # Upsample by 2, merge with the encoder skip tensor, then convolve.
        upsampled = Conv2DTranspose(filters, (3,3), strides=(2,2), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(tensor)
        merged = Concatenate(axis = 3)([skip, upsampled])
        return conv_pair(merged, filters)

    input_tensor = Input(shape=(psize, psize, nb))

    # Encoder: remember each normalized stage output for the skip connections.
    skips = []
    features = input_tensor
    for filters in (64, 128, 256, 512):
        features = conv_pair(features, filters)
        features = BatchNormalization()(features)
        skips.append(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    # Bottleneck
    features = conv_pair(features, 1024)

    # Decoder: deepest skip connection first.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        features = decode(features, skip, filters)

    outputs = Conv2D(nc, (1,1), activation = 'softmax', padding = 'same', kernel_initializer = 'he_normal')(features)

    model = Model(inputs = input_tensor, outputs = outputs)

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model

# Experiment configuration: identifiers, input geometry and output folders.

# Unique id for the experiment; it names the RESULTS sub-folder below.
# NOTE(review): the id says "5epochs_UNET", while the training call further
# down fits fcn_atr_orig for 200 epochs -- confirm the intended setup.
EXPT_ID="128_5epochs_UNET" #unique id for your experiment

# Root folder holding the samples, the test rasters and the RESULTS tree.
root_path="E:/ACADEMICS/DEEP_LEARNING_TUTORIAL_PASTECA/DEEP_LEARNING_TUTORIAL_PASTECA/GOMA_1947"
patch_size = 128  # side length of the square training patches
nc=6  # number of classes
nb = 1 # the number of channels of the input image, for panchromatic band, it is 1, for rgb it is 3
# NOTE(review): nb_epochs is not referenced by the visible training call,
# which passes epochs = 200 directly.
nb_epochs = 2 # number of epochs for training the model

# Folder for figures and timing logs (created if missing).
logs_folder=root_path+"/RESULTS/"+EXPT_ID+"/FIGURES_TIME_LOGS"
make_dir_paths(logs_folder)

# Folder for the trained weights (created if missing).
weights_folder = root_path+"/RESULTS/"+EXPT_ID+"/weights"
make_dir_paths(weights_folder)

# Full path of the weights file written after training.
w8dir = weights_folder + "/saved_weights.hdf5"

### load the training data
t0 = time.time()

#curated data
"adding path to the gray images dataset"
# HDF5 file with the datasets X_train, X_val, y_train and y_val.
smpldir=root_path+"/samples/training_goma.hdf5"

#reading the data
# Only the first 1000 training and the first 100 validation patches are read.
# Later code indexes the labels at channel 0 of axis 1 and moves axis 1 to
# the end, so the arrays are presumably (n, channels, rows, cols) -- confirm.
with h5py.File(smpldir, "r") as f:
    x_train = np.asarray(f["X_train"][:1000,:,:,:])
    x_val = np.asarray(f["X_val"][:100,:,:,:])
    y_train = np.asarray(f["y_train"][:1000,:,:,:])
    y_val = np.asarray(f["y_val"][:100,:,:,:])

t1 = time.time()
print("Finished importing data after %.2f mins" % ((t1-t0)/60.0))

# convert to onehot encoding
t0 = time.time()

# Crop the labels to patch_size x patch_size (the model input size) before
# one-hot encoding. Using patch_size instead of a literal 128 keeps the data
# and the network input in sync if the patch size is changed.
y_train_u = to_categorical_4d(y_train[:, :, :patch_size, :patch_size], nc)
y_val_u = to_categorical_4d(y_val[:, :, :patch_size, :patch_size], nc)

#reshape to channels last ordering
y_train_resh = swap_arr(y_train_u)
y_val_resh = swap_arr(y_val_u)

#reshaping the x patches too (same crop, then channels last)
x_train_resh = swap_arr(x_train[:, :, :patch_size, :patch_size])
x_val_resh = swap_arr(x_val[:, :, :patch_size, :patch_size])

t1 = time.time()
print("Finished preparing data after %.2f mins" % ((t1-t0)/60.0))

# Log the class counts of the training labels and save a class-frequency plot.
txtcontent=""

# Count the label values once and reuse the result for the log and the plot.
train_classes, train_counts = np.unique(y_train, return_counts=True)
messagetoprint="class distributions:\n%s" % [train_classes, train_counts]
txtcontent+=messagetoprint+"\n\n"


#plot the class distributions (training and validation bars share the axes,
#so they are labelled to keep them distinguishable)
plt.bar(train_classes, train_counts, label='train')
unique, counts = np.unique(y_val, return_counts=True)
plt.bar(unique, counts, label='validation')

plt.title('Class Frequency')
plt.xlabel('Class')
plt.ylabel('Frequency')
plt.legend()

plt.savefig(logs_folder+"/class_frequency.png")
plt.show()

# Free the raw arrays and the channels-first one-hot arrays; only the
# channels-last *_resh arrays are used from here on.
del x_train
del y_train
del x_val
del y_val
del y_train_u
del y_val_u

# the number of pixels per class in your training set
#save to csv
#np.unique(y_train, return_counts=True)

# train the model by loading the training data in batches
#no data augmentation
#datagen.fit(x_train)

bestparams = {'lrate': 0.1, 'momentum': 0.8, 'lrdecay': 0.001}

t0=time.time()

# The legacy `lr=` / `decay=` keyword arguments are deprecated, and `decay`
# is rejected by the non-legacy Keras optimizers. An InverseTimeDecay
# schedule with decay_steps=1 yields the same per-iteration learning rate,
# lrate / (1 + lrdecay * iteration), through the supported API.
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=bestparams['lrate'],
    decay_steps=1,
    decay_rate=bestparams['lrdecay'])

sgd=SGD(learning_rate=lr_schedule,
        momentum=bestparams['momentum'],
        nesterov=True)

#https://www.tensorflow.org/guide/keras/train_and_evaluate
# training dataset: shuffled with a 100-element buffer, batches of 16
train_dataset = tf.data.Dataset.from_tensor_slices((x_train_resh, y_train_resh))
train_dataset = train_dataset.shuffle(buffer_size=100).batch(16)

# validation dataset: batches of 16, no shuffling
val_dataset = tf.data.Dataset.from_tensor_slices((x_val_resh, y_val_resh))
val_dataset = val_dataset.batch(16)


#model_unet=unet(patch_size,nc,nb)
model_atr=fcn_atr_orig(patch_size,nc,nb)

#in tf2, categorical crossentropy can be computed over multiple dimensions

cce = tf.keras.losses.CategoricalCrossentropy()
cca = tf.keras.metrics.CategoricalAccuracy()
#model_unet.summary() #to view the network parameters

model_atr.compile(loss=cce,
                  optimizer=sgd,
                  metrics=[cca])

# NOTE(review): the epoch count is hard-coded here (and in the plot axis
# limits further down); nb_epochs defined in the configuration is not used.
history = model_atr.fit(train_dataset,
                         epochs = 200,
                        validation_data=val_dataset)

#Save the weights
# NOTE(review): these are fcn_atr_orig weights, but the inference cell loads
# w8dir into a model built with unet(...) -- confirm which architecture is
# intended.
model_atr.save_weights(w8dir)

t1=time.time()

print("Finished training after %.2f mins" % ((t1-t0)/60.0))

messagetoprint="Finished training after %.2f mins" % ((t1-t0)/60.0)
#print(messagetoprint)
txtcontent+=messagetoprint+"\n"

"visualize the learned weights"
#summarize history for accuracy
# Open a dedicated figure so the curves are not drawn on top of whatever
# figure is currently active.
plt.figure()
plt.plot(history.history['categorical_accuracy'])
plt.plot(history.history['val_categorical_accuracy'])
plt.title('model accuracy')
plt.ylabel ('accuracy')
plt.xlabel('epoch')
plt.axis([-1,200,0.50,0.95])
plt.legend(['train','test'], loc = 'upper left')
plt.savefig(logs_folder+"/model_accuracy.png")

#visualise and close the figure
#plt.show()
#plt.clf()


#summarize history for loss
# A second figure: without it the loss curves land on the accuracy axes and
# model_loss.png would contain both sets of curves.
plt.figure()
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.axis([-1,200,0.0,1])
plt.legend(['train','test'], loc='upper left')
plt.savefig(logs_folder+"/model_loss.png")

#visualise and close the figure
#plt.show()
#plt.clf()


# Free the training and validation arrays.
del x_train_resh
del y_train_resh
del x_val_resh
del y_val_resh



In [6]:
############################################test the model#####################################
##IMAGE TESTING

## functions for testing 
import subprocess, glob
from random import seed
from sklearn import metrics
from sklearn.metrics import confusion_matrix
import random
import itertools
# `seed` was re-imported from the standard `random` module just above, so
# this seeds Python's random generator (not NumPy's).
seed(1234)

BATCH_SIZE=16  # default number of windows per predict call in test()
WINDOW_SIZE=[256,256]  # default sliding-window size in test()
#predictions=root_path+"/preds/ATR_128/MINI_TILES/"
# Every GeoTIFF found in the raw_tif_totest folder is classified.
raw_files_pan = glob.glob(root_path+'/raw_tif_totest/*.tif')
mini_preds_folder = root_path+"/RESULTS/"+EXPT_ID+"/predictions/MINI_TILES"
probs_folder = root_path+"/RESULTS/"+EXPT_ID+"/probabilities/WHOLE_TILE_PROBS"

make_dir_paths(probs_folder)
make_dir_paths(mini_preds_folder)

# Weights file written by the training cell.
w8_fname = w8dir
# NOTE: this rebinds patch_size (128 during training) to the inference
# window size; the model below is built with this value.
patch_size=256
# Step between windows; smaller than the window, so neighbouring windows overlap.
STRIDE=228
nc=6  # number of classes

# Bounds passed to norm_rgbn to rescale the raster values.
# NOTE(review): presumably the value range of the input imagery -- confirm.
gmax=63450
gmin=0

## functions for loading the data
def nulltozero(arr):
    """Return a copy of `arr` in which every negative value is set to 0.

    The input array is left untouched.
    """
    cleaned = np.copy(arr)
    negative = cleaned < 0
    cleaned[negative] = 0
    return cleaned
#functions for loading the images and converting them to arrays

def img_to_array(*images):
    """Convert one or more GeoTIFF files to numpy arrays.

    Keyword arguments:
    *images -- filenames of the images to convert

    Returns a list with one array per filename, in the order given.
    """
    return [gtiff_to_array(img) for img in images]


def gtiff_to_array(imgfname):
    """Transform a geotiff to numpy array.

    Keyword arguments:
    imgfname -- filename of image to convert

    Returns an array of shape (rows, cols, bands), bands in raster order.

    Raises:
    IOError -- if GDAL cannot open the file
    ValueError -- if the raster has no bands
    """
    ds = gdal.Open(imgfname)
    if ds is None:
        # gdal.Open returns None on failure unless GDAL exceptions are enabled.
        raise IOError("Could not open raster file: %s" % imgfname)
    if ds.RasterCount < 1:
        raise ValueError("Raster file has no bands: %s" % imgfname)

    # GDAL band numbers start at 1. Stack all bands once along a new last
    # axis instead of growing the array with repeated concatenation.
    bands = [np.array(ds.GetRasterBand(band).ReadAsArray())
             for band in range(1, ds.RasterCount + 1)]
    return np.stack(bands, axis=2)

def reclassgts2(gtsarray):
    """Reclassify a ground truth array into two classes.

    Keyword arguments:
    gtsarray -- the ground truth dataset array; only channel 0 of the last
                axis is read

    Returns a uint8 array of shape (rows, cols): values 1, 3 and 4 become
    class 1, value 2 becomes class 2, everything else stays 0.
    """
    labels = gtsarray[:, :, 0]
    reclassarray = np.zeros(shape=(gtsarray.shape[0], gtsarray.shape[1]),
                            dtype=np.uint8)

    first_class = (labels == 1) | (labels == 3) | (labels == 4)
    reclassarray[first_class] = 1

    second_class = labels == 2
    reclassarray[second_class] = 2
    return reclassarray

#create idxarray for the creation of the map

def sample_idx(arr):
    """Return the index coordinates of every element of `arr`.

    Keyword arguments:
    arr -- the array whose positions are listed

    Returns an integer array of shape (arr.size, arr.ndim) holding the
    coordinates of the elements that are >= 0, in row-major order. No random
    sampling takes place. One row is requested per element, so an IndexError
    is raised when some element is not >= 0 (e.g. negative values or NaN).
    """
    data = np.asarray(arr)
    coords = np.array(np.where(data >= 0)).T
    # One row per element of the input, in flattened (row-major) order.
    row_ids = np.arange(data.size, dtype=np.int32)
    return coords[row_ids, :]

def write_geotiff(fname, data, geo_transform, projection):
    """Write the 2-D array `data` to `fname` as a single-band Byte GeoTIFF.

    geo_transform and projection are stored on the new dataset before the
    band is written.
    """
    n_rows, n_cols = data.shape
    out_ds = gdal.GetDriverByName('GTiff').Create(fname, n_cols, n_rows, 1,
                                                  gdal.GDT_Byte)
    out_ds.SetGeoTransform(geo_transform)
    out_ds.SetProjection(projection)
    out_ds.GetRasterBand(1).WriteArray(data)
    out_ds = None  # dropping the reference closes the file

def swap_axes(arr):
    """Add a leading axis of length 1 to `arr` (no axes are swapped)."""
    data = np.asanyarray(arr)
    return data.reshape((1,) + data.shape)

##attempting to perform classification by loading the entire tiles

#remove nan
def nulltozero_g(arr):
    """Return a copy of `arr` in which every NaN is replaced by 0.

    The input array is left untouched.
    """
    cleaned = np.copy(arr)
    nan_positions = np.isnan(arr)
    cleaned[nan_positions] = 0
    return cleaned

def extract_geometry(path_dir):
    """Open the raster at `path_dir` read-only and return its geotransform."""
    return gdal.Open(path_dir, gdal.GA_ReadOnly).GetGeoTransform()

# Utils adapting the prediction scheme in https://github.com/nshaud/DeepNetsForEO/blob/master/SegNet_PyTorch_v2.ipynb
#https://machinelearningmastery.com/reproducible-results-neural-networks-keras/


def get_random_pos(img, window_shape):
    """Pick a random 2D window of shape window_shape inside the image.

    img -- array whose last two axes are the ones being windowed
    window_shape -- (w, h) size of the window along those two axes

    Returns (x1, x2, y1, y2) so that img[..., x1:x2, y1:y2] is the window.
    Raises ValueError (from random.randint) if the window is larger than
    the image.
    """
    w, h = window_shape
    W, H = img.shape[-2:]
    # random.randint includes both end points, so the largest valid offset
    # is W - w (not W - w - 1). This lets the window touch the far edge and
    # makes a window equal to the image size valid.
    x1 = random.randint(0, W - w)
    y1 = random.randint(0, H - h)
    return x1, x1 + w, y1, y1 + h

def accuracy(input, target):
    """Return the percentage of elements where `input` equals `target`."""
    matching = np.count_nonzero(input == target)
    return 100 * float(matching) / target.size

def sliding_window(top, step=10, window_size=(32,32)):
    """ Slide a window of size window_size across the image with a stride of step.

    Yields (x, y, w, h) tuples. A window that would run past the end of an
    axis is moved back so that it ends exactly at the border, which means
    the same position can be yielded more than once.
    """
    for row in range(0, top.shape[0], step):
        # Clamp the start so the window does not exceed the image height.
        x = min(row, top.shape[0] - window_size[0])
        for col in range(0, top.shape[1], step):
            # Same clamping along the second axis.
            y = min(col, top.shape[1] - window_size[1])
            yield x, y, window_size[0], window_size[1]

def count_sliding_window(top, step=10, window_size=(32,32)):
    """ Count the number of windows in an image.

    One window is produced per (row start, column start) pair, with starts
    spaced by `step` along each axis. Clamping windows at the border does
    not change how many there are, so `window_size` does not affect the
    result; it is kept for signature compatibility with sliding_window.
    """
    n_row_starts = len(range(0, top.shape[0], step))
    n_col_starts = len(range(0, top.shape[1], step))
    return n_row_starts * n_col_starts

def grouper(n, iterable):
    """ Browse an iterator by chunk of n elements.

    Yields tuples of up to n items; the last tuple may be shorter. Stops
    as soon as a chunk comes back empty.
    """
    source = iter(iterable)
    # iter(callable, sentinel) keeps calling until an empty tuple comes back.
    for chunk in iter(lambda: tuple(itertools.islice(source, n)), ()):
        yield chunk

# enable the testing of the entire test set
def test(net, img_arr,N_CLASSES,stride=WINDOW_SIZE[0], batch_size=BATCH_SIZE, window_size=WINDOW_SIZE):
    """Classify whole images with a sliding window.

    net -- model exposing predict(batch); for a batch of windows it must
           return one (w, h, N_CLASSES) score array per window
    img_arr -- iterable of images indexed as (rows, cols, ...)
    N_CLASSES -- number of classes in the model output
    stride -- step between consecutive windows
    batch_size -- number of windows passed to net.predict at once
    window_size -- (w, h) size of each window

    Returns (all_preds, all_probs): for every image, the per-pixel argmax
    class index and the per-pixel maximum class score. Where windows
    overlap, the scores of the window processed last overwrite the
    earlier ones.
    """
    all_preds = []
    all_probs = []

    for img in img_arr:
        # Per-pixel class scores for the whole image.
        scores = np.zeros(img.shape[:2] + (N_CLASSES,))
        windows = sliding_window(img, step=stride, window_size=window_size)
        for coords in grouper(batch_size, windows):
            # Stack the windows of this chunk into a single batch.
            batch = np.stack([img[x:x+w, y:y+h] for x, y, w, h in coords],
                             axis=0)
            # Do the inference
            outs = net.predict(batch)
            del batch  # clear mem

            # Fill in the results array (assignment, not accumulation, so
            # overlapping regions are not summed).
            for out, (x, y, w, h) in zip(outs, coords):
                scores[x:x+w, y:y+h] = out
            del outs

        # Maximum score per pixel and the class it belongs to.
        all_probs.append(np.max(scores, axis=2))
        all_preds.append(np.argmax(scores, axis=2))

    return all_preds,all_probs

#Reclassifying the array and giving it coordinates
def reclass_gts(gtsarray, ccolors=None):
    """Reclassify ground truth dataset array to single class numbers.

    Keyword arguments:
    gtsarray -- the ground truth dataset array; the first three channels of
                the last axis are compared against each colour
    ccolors -- sequence of class colours (each with at least three
               components); the colour at position i maps to class i+1.
               Defaults to the module-level `_ccolors`, which must then be
               defined by the caller's environment.

    Returns a uint8 array of shape (rows, cols); pixels matching no colour
    stay 0.
    """
    if ccolors is None:
        ccolors = _ccolors

    reclassarray = np.zeros(shape=(gtsarray.shape[0], gtsarray.shape[1]),
                            dtype=np.uint8)
    for cnum, color in enumerate(ccolors, start=1):
        mask = ((gtsarray[:, :, 0] == color[0])
                & (gtsarray[:, :, 1] == color[1])
                & (gtsarray[:, :, 2] == color[2]))
        reclassarray[mask] = cnum
    return reclassarray

def save_to_1band(arr,out_path,geom,proj):
    """Write the 2-D array `arr` to `out_path` as a one-band UInt32 GeoTIFF.

    arr -- array with two dimensions (rows, cols), e.g. a class map
    out_path -- destination file name
    geom -- geotransform stored on the output dataset
    proj -- projection stored on the output dataset
    """
    # Add a trailing band axis so the band count can be read from the shape.
    band_stack = np.expand_dims(arr, axis=2)
    nrows, ncols, nbands = band_stack.shape[:3]

    out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ncols, nrows,
                                                  nbands, gdal.GDT_UInt32)
    out_ds.SetGeoTransform(geom)
    out_ds.SetProjection(proj)
    out_ds.GetRasterBand(1).WriteArray(band_stack[:, :, 0])

    out_ds = None  # dropping the reference closes the file

def save_prob_1band(arr,out_path,geom,proj):
    """Write the 2-D array `arr` to `out_path` as a one-band Float32 GeoTIFF.

    arr -- array with two dimensions (rows, cols), e.g. a probability map
    out_path -- destination file name
    geom -- geotransform stored on the output dataset
    proj -- projection stored on the output dataset
    """
    # Add a trailing band axis so the band count can be read from the shape.
    band_stack = np.expand_dims(arr, axis=2)
    nrows, ncols, nbands = band_stack.shape[:3]

    out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ncols, nrows,
                                                  nbands, gdal.GDT_Float32)
    out_ds.SetGeoTransform(geom)
    out_ds.SetProjection(proj)
    out_ds.GetRasterBand(1).WriteArray(band_stack[:, :, 0])

    out_ds = None  # dropping the reference closes the file

# normalize the top array
def norm_rgbn(data,gmax,gmin):
    """
    Linearly rescale the data so that gmin maps to 0 and gmax maps to 1.

    data -- the image array (converted to float before scaling)
    gmax -- value mapped to 1
    gmin -- value mapped to 0
    """
    value_range = gmax - gmin
    shifted = data.astype(float) - gmin
    return shifted / value_range

def data_proj(arr_path):
    """Return (geotransform, projection) of the raster at `arr_path`.

    The file is opened read-only; the projection is the string returned by
    GetProjectionRef.
    """
    source = gdal.Open(arr_path, gdal.GA_ReadOnly)
    return source.GetGeoTransform(), source.GetProjectionRef()


#call the trained model
t0 = time.time()

#model_fcn=fcn_atr_orig(patch_size,nc,1)
# NOTE(review): the training cell saves weights of fcn_atr_orig to w8dir,
# while a unet is built here. Loading weights from a different architecture
# is expected to fail -- confirm that training and inference use the same
# model builder.
model_unet=unet(patch_size,nc,nb)

#load the weights
model_unet.load_weights(w8_fname)
for rgbn in raw_files_pan:
    rgbn_raw=img_to_array(rgbn)
    geo_info,projec_info=data_proj(rgbn)

    #normalize the raster values using the gmin/gmax bounds
    rgbn_norm=norm_rgbn(rgbn_raw[0],gmax,gmin)
    del rgbn_raw

    # allpred[0] holds the class maps, allpred[1] the maximum-score maps
    allpred= test(model_unet,[rgbn_norm],nc,stride=STRIDE)

    file_name=os.path.split(rgbn)[-1]
    # argmax indices start at 0; +1 stores the classes as 1..nc
    save_to_1band(allpred[0][0]+1,mini_preds_folder+"/"+file_name,geo_info,projec_info)
    save_prob_1band(allpred[1][0],probs_folder+"/"+file_name,geo_info,projec_info)


t1 = time.time()
print("Finished testing after %.2f mins" % ((t1-t0)/60.0))

#txtcontent=""

messagetoprint="Finished testing after %.2f mins" % ((t1-t0)/60.0)
#print(messagetoprint)
# txtcontent already holds the class distribution and training time
# collected by the training cell.
txtcontent+=messagetoprint+"\n\n"

# The context manager closes the log file even if a write fails.
with open(logs_folder+"/model_training_testing.txt", 'w') as f:
    f.write(EXPT_ID+" time_log"+"\n\n")
    f.write(txtcontent)

created path
Finished testing after 0.06 mins


In [None]:
#CLASSIFY THE ENTIRE STUDY AREA
#This is the path to the image covering your study area
raw_files_large=["H:/PASTECA/DATA/GOMA_1947/Goma1947_UTM35S_ortho_1m_test.tif"]

large_preds_folder = root_path+"/RESULTS/"+EXPT_ID+"/predictions/WHOLE_TILE"
probs_folder = root_path+"/RESULTS/"+EXPT_ID+"/probabilities/WHOLE_TILE_PROBS"
make_dir_paths(large_preds_folder)
make_dir_paths(probs_folder)


# Start timing once, right before the classification loop.
t0 = time.time()
for rgbn in raw_files_large:
    rgbn_raw=img_to_array(rgbn)
    geo_info,projec_info=data_proj(rgbn)

    #normalize the raster values using the gmin/gmax bounds
    rgbn_norm=norm_rgbn(rgbn_raw[0],gmax,gmin)
    del rgbn_raw

    # Same window step as the mini-tile run (STRIDE is 228).
    allpred= test(model_unet,[rgbn_norm],nc,stride=STRIDE)

    file_name=os.path.split(rgbn)[-1]
    # argmax indices start at 0; +1 stores the classes as 1..nc
    save_to_1band(allpred[0][0]+1,large_preds_folder+"/"+file_name,geo_info,projec_info)
    save_prob_1band(allpred[1][0],probs_folder+"/"+file_name,geo_info,projec_info)

t1 = time.time()
print("Finished testing whole tile after %.2f mins" % ((t1-t0)/60.0))
txtcontent=""
messagetoprint="Finished testing whole tile after %.2f mins" % ((t1-t0)/60.0)
#print(messagetoprint)
txtcontent+=messagetoprint+"\n\n"

# The context manager closes the log file even if a write fails.
with open(logs_folder+"/testing_large_file.txt", 'w') as f:
    f.write(EXPT_ID+" testing_time"+"\n\n")
    f.write(txtcontent)