## Import libraries

In [2]:
import tensorflow as tf

In [3]:
from assimilation import assimilation
from stackage import stackage

In [4]:
import numpy as np

In [5]:
from sklearn.model_selection import train_test_split

## Functions

In [6]:
def reshaping(list_input):
    L = np.array(list_input)
    L = L.reshape(L.size,-1)
    
    return L

In [7]:
def normalise_01(image_data):
    stack = image_data
    image_data -= np.min(stack, axis=0)
    image_data /= (np.max(stack, axis=0) - np.min(stack, axis=0))
    return image_data

## Data preprocessing

In [8]:
import h5py

f = h5py.File('Images/big_raster/tr_ALLS1_patch/tr_Ymodel.hdf5', 'r')
S1_patch_id = np.array(f["id_sar"])
S2_patch_id = np.array(f["id_rgb"])
S1_patch_lulc = np.array(f["lulc_sar"])
S2_patch_lulc = np.array(f["lulc_rgb"])
S1_patch_path = f["patch_sar_path"]
S2_patch_path = f["patch_rgb_path"]
                         
f = h5py.File('Images/big_raster/g_ALLS1_patch/g_Ymodel.hdf5', 'r')
S1_grid_id = np.array(f["id_sar"])          
S2_grid_id = np.array(f["id_rgb"])                     
S1_grid_path = f["patch_sar_path"]
S2_grid_path = f["patch_rgb_path"]

In [9]:
stack = stackage(S2_grid_path)
S2_stack_grid = normalise_01(stack)

In [10]:
S2_patch_id = S2_patch_id.reshape(S2_patch_id.size,1)

S2_patch_lulc = S2_patch_lulc.reshape(S2_patch_lulc.size,1)

S2_grid_id = S2_grid_id.reshape(S2_grid_id.size,1)

### Dividing train and test patches

In [11]:
patch_id_train, patch_id_test, patch_lulc_train, patch_lulc_test = train_test_split(S2_patch_id,S2_patch_lulc, test_size=0.2, random_state=100)
patch_id_train = reshaping(patch_id_train)
patch_id_test = reshaping(patch_id_test)
patch_lulc_train = reshaping(patch_lulc_train)
patch_lulc_test = reshaping(patch_lulc_test)  

In [12]:
stack_train = assimilation(patch_id_train, patch_id_test, patch_lulc_test,S2_patch_path)[0]
stack_test = assimilation(patch_id_train, patch_id_test, patch_lulc_test,S2_patch_path)[1]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[ 714.  544.  346. 1952.]
b'Images/big_raster/tr_RBGNIR_patch/tr_RBGNIR_14054_0.tif'
[14054]
[ 304.  456.  223. 2378.]
b'Images/big_raster/tr_RBGNIR_patch/tr_RBGNIR_23446_0.tif'
[23446]
[ 855.  719.  492. 2412.]
b'Images/big_raster/tr_RBGNIR_patch/tr_RBGNIR_104605_0.tif'
[104605]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[1452. 1236.  951. 1761.]
b'Images/big_raster/tr_RBGNIR_patch/tr_RBGNIR_63863_1.tif'
[63863]
[ 869.  693.  466. 1571.]
b'Images/big_raster/tr_RBGNIR_patch/tr_RBGNIR_62858_0.tif'
[62858]
[1290.  925.  632. 2426.]
b'Images/big_raster/tr_RBGNIR_patch/tr_RBGNIR_87427_0.tif'
[87427]


## Model

https://becominghuman.ai/understanding-and-coding-inception-module-in-keras-eb56e9056b4b

https://faroit.com/keras-docs/1.1.1/applications/#inceptionv3

In [None]:

inputs =tf.keras.Input(shape=(10,10,3))


base_model = tf.keras.applications.inception_v3.InceptionV3(include_top=True, weights='imagenet', input_tensor=None)



# first: train only the top layers (which were randomly initialized)
# i.e. freeze all convolutional InceptionV3 layers
for layer in base_model.layers:
    layer.trainable = False
    
# compile the model (should be done *after* setting layers to non-trainable)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')


history = model.fit(stack_train,
          patch_lulc_train, 
          validation_data=(stack_test, patch_lulc_test),
          epochs=100,
          batch_size=batch_size)

In [None]:
predict = model.predict(S2_stack_grid)

In [None]:
import os, signal

os.kill(os.getpid() , signal.SIGKILL)