[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adugnag/deSpeckNet-TF-GEE/blob/main/notebooks/train.ipynb)

# Setup software libraries



In [None]:
# Cloud authentication.
# Authenticates this Colab session with the user's Google account; presumably
# required so the gs:// paths used when params['EXPORT'] == 'GCS' are reachable.
from google.colab import auth
auth.authenticate_user()

In [None]:
import tensorflow as tf
import numpy as np
import os

# NOTE(review): assumes TF 2.x (tf.keras.layers.RandomFlip is used below), where
# eager execution is the default and no explicit enable call is needed.
# Print the installed TensorFlow version for reproducibility.
print(tf.__version__)

In [None]:
#@title Helper functions
#simple data augmentation
class dataAugment(tf.keras.layers.Layer):
  """Random horizontal flip applied to an (inputs, labels, masks) triple.

  Used as a tf.data map function. It returns (inputs, (labels, inputs), masks),
  so the flipped inputs double as the target of the model's second output and
  the masks travel along as the third element.
  """

  def __init__(self, seed=42):
    super().__init__()

    def make_flip():
      return tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)

    # One flip layer per tensor, all built with the same seed so they are
    # intended to make the same random flip decision.
    # NOTE(review): this relies on identically seeded layers staying in
    # lock-step (each is called once per element) -- confirm for the TF version.
    self.augment_inputs = make_flip()
    self.augment_labels = make_flip()
    self.augment_masks = make_flip()

  def call(self, inputs, labels, masks):
    flipped_inputs = self.augment_inputs(inputs)
    flipped_labels = self.augment_labels(labels)
    flipped_masks = self.augment_masks(masks)
    targets = (flipped_labels, flipped_inputs)
    return flipped_inputs, targets, flipped_masks


def parse_tfrecord(example_proto):
  """Decode one serialized example using the module-level FEATURES_DICT spec.

  Returns a dict mapping each feature name to its parsed tensor.
  """
  return tf.io.parse_single_example(serialized=example_proto,
                                    features=FEATURES_DICT)


def to_tuple_train(inputs):
  """Split a parsed training record into (data, label, masks) tensors.

  The features listed in the module-level FEATURES are stacked along a new
  last axis, giving an (H, W, C) tensor. The channels are then cut into:
    * data  -- the first len(params['BANDS']) channels,
    * label -- the next 2 channels when there are exactly 2 bands, else 1,
    * masks -- the single channel right after the label.
  NOTE(review): params['MASK'] may list more than one band, but only the
  first mask channel is kept here.
  """
  n_bands = len(params['BANDS'])
  n_label = 2 if n_bands == 2 else 1

  per_feature = [inputs.get(name) for name in FEATURES]
  hwc = tf.transpose(tf.stack(per_feature, axis=0), [1, 2, 0])

  label_end = n_bands + n_label
  data = hwc[:, :, :n_bands]
  label = hwc[:, :, n_bands:label_end]
  masks = hwc[:, :, label_end:label_end + 1]
  return data, label, masks

def to_tuple_tune(inputs):
  """Split a parsed fine-tuning record into (data, (label, data)).

  The features listed in the module-level FEATURES are stacked along a new
  last axis (H, W, C). The first len(params['BANDS']) channels are the model
  input and all remaining channels are the label. The returned target pair
  lines up with the model's [clean, noisy] outputs.
  """
  n_bands = len(params['BANDS'])
  per_feature = [inputs.get(name) for name in FEATURES]
  hwc = tf.transpose(tf.stack(per_feature, axis=0), [1, 2, 0])
  data, label = hwc[:, :, :n_bands], hwc[:, :, n_bands:]
  return data, (label, data)

def get_dataset(pattern, params):
  """Build an unbatched tf.data pipeline from GZIP-compressed TFRecord files.

  Args:
    pattern: glob pattern (local, Drive or gs://) matching the TFRecord files.
    params: parameter dict; params['MODE'] selects the element structure.

  Returns:
    In 'training' mode, elements are (inputs, (labels, inputs), masks) after a
    random horizontal flip (dataAugment). In any other mode, elements are
    (data, (label, data)).

  Raises:
    FileNotFoundError: if ``pattern`` matches no files.
  """
  files = tf.io.gfile.glob(pattern)
  if not files:
    # Fail fast: an empty file list would otherwise give an empty dataset and
    # the problem would only surface later, during model.fit.
    raise FileNotFoundError('No TFRecord files match pattern: ' + pattern)
  dataset = tf.data.TFRecordDataset(files, compression_type='GZIP')
  dataset = dataset.map(parse_tfrecord, num_parallel_calls=5)
  if params['MODE'] == 'training':
    dataset = dataset.map(to_tuple_train, num_parallel_calls=5)
    # Same seeded flip applied to inputs, labels and masks.
    dataset = dataset.map(dataAugment(), num_parallel_calls=5)
  else:
    dataset = dataset.map(to_tuple_tune, num_parallel_calls=5)
  return dataset


"""# Training data: use the tf.data api to build our data pipeline"""
def get_training_dataset(params):
    """Return the shuffled, batched, infinitely repeated training dataset.

    Reads the TFRecords matching TRAINING_BASE* under FOLDER. They come from
    the GCS bucket when params['EXPORT'] == 'GCS', and from params['DRIVE']
    otherwise. The shuffle buffer is params['BUFFER_SIZE'] and the batch size
    is params['BATCH_SIZE'].

    The parsing helpers read the module-level FEATURES / FEATURES_DICT
    directly, so only ``params`` is taken here. The previous definition
    declared its own parameters ``global``, which is a SyntaxError.
    """
    if params['EXPORT'] == 'GCS':
        root = 'gs://' + params['BUCKET']
    else:
        root = params['DRIVE']
    pattern = root + '/' + params['FOLDER'] + '/' + params['TRAINING_BASE'] + '*'
    dataset = get_dataset(pattern, params)
    return dataset.shuffle(params['BUFFER_SIZE']).batch(params['BATCH_SIZE']).repeat()


def get_eval_dataset(params):
    """Return the evaluation dataset: batches of a single example, repeated forever.

    Reads the TFRecords matching EVAL_BASE* under FOLDER. They come from the
    GCS bucket when params['EXPORT'] == 'GCS', and from params['DRIVE']
    otherwise.
    """
    root = 'gs://' + params['BUCKET'] if params['EXPORT'] == 'GCS' else params['DRIVE']
    pattern = root + '/' + params['FOLDER'] + '/' + params['EVAL_BASE'] + '*'
    return get_dataset(pattern, params).batch(1).repeat()


###########################################
# 4. MODEL
###########################################

def deSpeckNet(depth,filters,image_channels, use_bnorm=True):
    """Build the two-branch deSpeckNet model.

    Both branches read the same input. Each is a stack of:
    Conv+ReLU, then (depth-2) x [Conv (+BN if use_bnorm) + ReLU], then a final
    Conv back to ``image_channels``. The first branch's output is named
    'speckle1' and the second's 'clean1'. 'noisy1' is their element-wise sum.

    Args:
        depth: number of conv layers per branch.
        filters: filters in every hidden conv layer.
        image_channels: channels of the input and of each output.
        use_bnorm: add BatchNormalization after each hidden conv. Previously
            BN was added even when this was False.

    Returns:
        tf.keras.Model with outputs [clean1, noisy1].
    """
    inpt = tf.keras.layers.Input(shape=(None, None, image_channels), name='input0')
    speckle, next_index = _despeck_branch(inpt, depth, filters, image_channels, use_bnorm,
                                          first_index=1, output_name='speckle1')
    clean, _ = _despeck_branch(inpt, depth, filters, image_channels, use_bnorm,
                               first_index=next_index, output_name='clean1')
    noisy = tf.keras.layers.Add(name='noisy1')([speckle, clean])
    return tf.keras.Model(inputs=inpt, outputs=[clean, noisy])


def _despeck_branch(x, depth, filters, image_channels, use_bnorm, first_index, output_name):
    """Build one conv branch of deSpeckNet; return (output tensor, next free layer index).

    Layers are numbered from ``first_index`` exactly as the original inline
    code numbered them, so layer names stay unchanged.
    """
    conv_kwargs = {
        'kernel_size': (3, 3),
        'strides': (1, 1),
        'kernel_initializer': 'glorot_normal',
        'padding': 'same',
        'use_bias': True,
    }
    idx = first_index
    x = tf.keras.layers.Conv2D(filters=filters, name='conv' + str(idx), **conv_kwargs)(x)
    idx += 1
    x = tf.keras.layers.Activation('relu', name='relu' + str(idx))(x)
    for _ in range(depth - 2):
        idx += 1
        x = tf.keras.layers.Conv2D(filters=filters, name='conv' + str(idx), **conv_kwargs)(x)
        if use_bnorm:
            idx += 1
            x = tf.keras.layers.BatchNormalization(axis=3, momentum=0.0, epsilon=0.0001,
                                                   name='bn' + str(idx))(x)
        idx += 1
        x = tf.keras.layers.Activation('relu', name='relu' + str(idx))(x)
    x = tf.keras.layers.Conv2D(filters=image_channels, name=output_name, **conv_kwargs)(x)
    # The original advanced the counter twice around the output conv; keep
    # that so the next branch's layer names are identical to before.
    return x, idx + 2

# Learning rate scheduler
def lr_schedule(epoch):
    """Step-decay learning rate for tf.keras.callbacks.LearningRateScheduler.

    Epochs up to 30 use 1e-3, epochs 31-60 use 1e-3/10, and every later epoch
    uses 1e-3/20. The chosen value is also passed to tf.summary.scalar under
    the tag 'learning rate'.
    """
    base_lr = 1e-3
    if epoch > 60:
        lr = base_lr / 20
    elif epoch > 30:
        lr = base_lr / 10
    else:
        lr = base_lr
    tf.summary.scalar('learning rate', data=lr, step=epoch)
    return lr

# Total variation loss
def TVloss(y_true, y_pred):
  """Total-variation penalty of the prediction, summed over the batch.

  Has the Keras (y_true, y_pred) loss signature; y_true is not used.
  """
  del y_true  # required by the loss signature only
  per_image_tv = tf.image.total_variation(y_pred)
  return tf.reduce_sum(per_image_tv)


# Setup parameters

In [None]:
# Parameters controlling data location, run mode, data layout and training.
params = {
          # Storage: 'GCS' uses gs://BUCKET/FOLDER/...; any other value uses DRIVE/FOLDER/...
            'EXPORT': 'GCS',
            'BUCKET' : 'senalerts_dl3',
            'DRIVE' : '/content/drive',
            'FOLDER' : 'deSpeckNet',
          # File-name prefixes of the TFRecords ('*' is appended when globbing).
            'TRAINING_BASE' : 'training_deSpeckNet_DUAL_Median_mask_tune',
            'EVAL_BASE' : 'eval_deSpeckNet_DUAL_median_mask_tune',
          # 'training' builds a new model; any other value fine-tunes the model loaded from MODEL_DIR.
            'MODE' : 'tuning',
          # Should be the same bands selected during data prep
            'BANDS': ['VV', 'VH'],
          # Target bands read in 'training' mode.
            'RESPONSE_TR' : ['VV_median', 'VH_median'],
          # Target bands read in any other (fine-tuning) mode.
            'RESPONSE_TU' : ['VV', 'VH'],
          # Mask bands read in 'training' mode.
          # NOTE(review): to_tuple_train keeps only the first mask channel.
            'MASK' : ['VV_mask', 'VH_mask'],
          # Patch side length used to parse every feature of the records.
            'KERNEL_SIZE' : 40,
          # NOTE(review): KERNEL_SHAPE and KERNEL_BUFFER are not read in this notebook section.
            'KERNEL_SHAPE' : [40, 40],
            'KERNEL_BUFFER' : [20, 20],
          # Specify model training parameters.
            'BATCH_SIZE' : 16,
          # TRAIN_SIZE / EVAL_SIZE / EPOCHS apply to 'training' mode; tuning mode uses fixed smaller values.
          # NOTE(review): make sure the mode setup reads this key as 'EPOCHS'.
            'TRAIN_SIZE':32000,
            'EVAL_SIZE':8000,
            'EPOCHS' : 50,
          # Shuffle buffer size passed to the training dataset.
            'BUFFER_SIZE': 2000,
          # When True, the compile step uses the configuration that includes TVloss for 'clean1'.
            'TV_LOSS' : False,
          # Conv layers per branch and filters per hidden conv layer of deSpeckNet.
            'DEPTH' : 17,
            'FILTERS' : 64,
          # Folder name (under FOLDER) of the saved model, i.e. MODEL_DIR.
            'MODEL_NAME': 'model_deSpeckNet_DUAL_aug_mask_v1'
            }


if params['MODE'] == 'training':
    # Training records hold the input bands, the RESPONSE_TR targets and the masks.
    FEATURES = params['BANDS'] + params['RESPONSE_TR'] + params['MASK']
    BUFFER_SIZE = params['BUFFER_SIZE']
    TRAIN_SIZE = params['TRAIN_SIZE']
    VALIDATION_SIZE = params['EVAL_SIZE']
    # The dict key is 'EPOCHS' and model.fit reads the global EPOCHS. The old
    # `EPOCH = params['EPOCH']` raised KeyError, so training could never start.
    EPOCHS = params['EPOCHS']
else:
    # Fine-tuning records hold the input bands followed by the RESPONSE_TU targets.
    FEATURES = params['BANDS'] + params['RESPONSE_TU']
    # NOTE(review): get_training_dataset shuffles with params['BUFFER_SIZE'],
    # not this BUFFER_SIZE global, so 500 is not applied there -- confirm intent.
    BUFFER_SIZE = 500
    TRAIN_SIZE = 4000
    VALIDATION_SIZE = 1000
    EPOCHS = 1

# Specify the size and shape of patches expected by the model.
KERNEL_SHAPE = [params['KERNEL_SIZE'], params['KERNEL_SIZE']]

# Every feature is stored as a float32 patch of KERNEL_SHAPE.
COLUMNS = [tf.io.FixedLenFeature(shape=KERNEL_SHAPE, dtype=tf.float32) for _ in FEATURES]
FEATURES_DICT = dict(zip(FEATURES, COLUMNS))

# Input and output channel count of the model.
IMAGE_CHANNELS = len(params['BANDS'])

# Where the model is saved after training and loaded from for fine-tuning.
if params['EXPORT'] == 'GCS':
    MODEL_DIR = 'gs://' + params['BUCKET'] + '/' + params['FOLDER'] + '/' + params['MODEL_NAME']
else:
    MODEL_DIR = params['DRIVE'] + '/' + params['FOLDER'] + '/' + params['MODEL_NAME']


# Training data


In [None]:
# Use the tf.data API to build the training and evaluation pipelines.
training = get_training_dataset(params)
evaluation = get_eval_dataset(params)

# Peek at one training batch to sanity-check structure, shapes and dtypes.
print(next(iter(training.take(1))))

# Build Model



In [None]:
if params['MODE'] == 'training':
    # Train from scratch.
    model = deSpeckNet(depth=params['DEPTH'], filters=params['FILTERS'], image_channels=IMAGE_CHANNELS)
else:
    # For fine tuning: start from the model saved in MODEL_DIR by a training run.
    model = tf.keras.models.load_model(MODEL_DIR)

# Summarize and plot the model that will actually be trained.
model.summary()
# Kept as the last expression of the cell so the notebook renders the diagram.
tf.keras.utils.plot_model(model, show_shapes=True)

# Train

In [None]:
# Load the TensorBoard notebook extension.
%load_ext tensorboard

# NOTE(review): datetime and version are not used in the visible cells.
from datetime import datetime
from packaging import version

import tensorboard
# NOTE(review): not the last statement of the cell, so presumably not displayed.
tensorboard.__version__

# Define the Keras TensorBoard callback.
# NOTE(review): plain mkdir reports an error if the folder already exists
# (e.g. on a re-run); the callback writes its logs under `logdir`.
!mkdir 'model_deSpeckNet_DUAL_aug_mask_v1'
# Must match the --logdir passed to the %tensorboard magic further down.
logdir= 'model_deSpeckNet_DUAL_aug_mask_v1'
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
# Sets the optimizer learning rate from lr_schedule(epoch) at each epoch.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_schedule)

In [None]:
if params['TV_LOSS']:
    # Keras takes one loss per output name. The previous dict literal repeated
    # the key 'clean1', so the TV entries silently replaced the MSE loss and
    # its weight (clean1 effectively got TVloss * 0.0). The two terms are now
    # combined into one loss that carries the originally written weights:
    # clean MSE x 100, TV x 0, noisy MSE x 1.
    CLEAN_MSE_WEIGHT = 100.0
    TV_WEIGHT = 0.0

    def clean_loss(y_true, y_pred):
        """Weighted MSE to the clean target plus a weighted total-variation penalty."""
        mse = tf.keras.losses.mean_squared_error(y_true, y_pred)
        return CLEAN_MSE_WEIGHT * mse + TV_WEIGHT * TVloss(y_true, y_pred)

    # NOTE(review): a model compiled with this custom loss may need
    # custom_objects={'clean_loss': clean_loss} when reloaded via load_model.
    loss_funcs = {'clean1': clean_loss, 'noisy1': 'mean_squared_error'}
    loss_weights = {'clean1': 1.0, 'noisy1': 1.0}
else:
    loss_funcs = {'clean1': 'mean_squared_error', 'noisy1': 'mean_squared_error'}
    loss_weights = {'clean1': 1.0, 'noisy1': 100.0}

# Compile
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=loss_funcs, loss_weights=loss_weights)

In [None]:

# Training batches hold BATCH_SIZE examples. The evaluation dataset is
# batched with size 1 (get_eval_dataset), so one validation step is one
# example. Dividing VALIDATION_SIZE by BATCH_SIZE here meant only the first
# 1/BATCH_SIZE of the (unshuffled) eval set was ever evaluated.
model.fit(
    x=training,
    epochs=EPOCHS,
    steps_per_epoch=TRAIN_SIZE // params['BATCH_SIZE'],
    validation_data=evaluation,
    validation_steps=VALIDATION_SIZE,
    callbacks=[tensorboard_callback, lr_scheduler])

In [None]:
# Show TensorBoard inline; the directory must match `logdir` given to the TensorBoard callback.
%tensorboard --logdir 'model_deSpeckNet_DUAL_aug_mask_v1'

In [None]:
if params['MODE'] == 'training':
    # Save the trained model where the fine-tuning run will load it from.
    model.save(MODEL_DIR, save_format='tf')
else:
    # Fine-tuned models go to a 'tune' folder under FOLDER, on the same storage
    # backend as everything else. The path was previously hard-coded to GCS
    # even when params['EXPORT'] selected Drive.
    if params['EXPORT'] == 'GCS':
        MODEL_DIR = 'gs://' + params['BUCKET'] + '/' + params['FOLDER'] + '/' + 'tune'
    else:
        MODEL_DIR = params['DRIVE'] + '/' + params['FOLDER'] + '/' + 'tune'
    model.save(MODEL_DIR, save_format='tf')