# Training a ResUnetA model

Training of a `ResUnetA` architecture with the `.npz` files from the previous notebooks

This notebook:

 * creates TensorFlow datasets from the previously created `.npz` files. The datasets are loaded and transformed on the fly, which reduces RAM usage and makes it possible to process large AOIs
 * performs training of the model  
 * tests the model's predictions on a validation batch
 
## NOTE

This workflow loads the `.npz` files from local disk (S3 is not used).

In [1]:
import os 
import json
import logging
from datetime import datetime
from functools import reduce
from typing import Callable, Tuple, List

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import wandb
from wandb.keras import WandbCallback

from eoflow.models.metrics import MCCMetric
from eoflow.models.segmentation_base import segmentation_metrics
from eoflow.models.losses import JaccardDistanceLoss, TanimotoDistanceLoss

from eoflow.models.segmentation_unets import ResUnetA

# ### Changing current directory 
# os.chdir('/home/lscalambrin/proyecto_integrador/segmentation/field-delineation-main')
# print(os.getcwd())

from fd.tf_viz_utils import ExtentBoundDistVisualizationCallback
from fd.training import TrainingConfig, get_dataset
from fd.utils import prepare_filesystem
from fd.metrics_extra import seg_metrics, mean_iou, mean_dice

from pprint import pprint
import sys

In [2]:
# Keep the notebook output quiet: only CRITICAL records pass the root logger,
# and the dedicated TensorFlow logger is switched off altogether.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)
logging.getLogger('tensorflow').disabled = True

In [3]:
# Display the physical GPU devices that TensorFlow reports.
tf.config.experimental.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [4]:
# Report how many physical GPU devices TensorFlow reports.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

Num GPUs Available:  1


## Set up configuration parameters

In [5]:
### paths
# Folder with the patchlet .npz files and CSV file passed to the config as `metadata_path`.
save_patchlet_npz = '/data/lscalambrin/proyecto_integrador/segmentation/pergamino/patchlets_npz'
df_path = '/data/lscalambrin/proyecto_integrador/segmentation/pergamino/patchlet-info.csv'
# Index appended to the model folder name, so each run can get its own output folder.
idx_model = 1
model_folder = f'/data/lscalambrin/proyecto_integrador/segmentation/pergamino/models/model{idx_model}'

n_classes = 2
batch_size = 5

# NOTE(review): `n_samples` (used for `iterations_per_epoch` below) is not defined in any
# visible cell of this notebook, so on a fresh kernel this cell raises NameError unless the
# name is defined elsewhere. TODO: confirm where it should come from and define it here.
training_config = TrainingConfig(
    # S3-related fields; the datasets in this notebook are built with `npz_from_s3=False`.
    bucket_name='bucket-name',
    aws_access_key_id='',
    aws_secret_access_key='',
    aws_region='eu-central-1',
    wandb_id=None, # change this with your wandb account 
    npz_folder=save_patchlet_npz,
    metadata_path=df_path,
    model_folder=model_folder,
    model_s3_folder='models/Castilla/2020-04',
    # No checkpoint folder: the model initialisation skips loading existing weights.
    chkpt_folder=None,
#     chkpt_folder='/home/ubuntu/pre-trained-model/checkpoints',
    input_shape=(256, 256, 4),
    n_classes=n_classes,
    batch_size=batch_size,
    # Passed to `fit` as `steps_per_epoch` in the training cell.
    iterations_per_epoch=n_samples//batch_size, 
    num_epochs=20,
    model_name='resunet-a',
    reference_names=['extent','boundary','distance'],
    augmentations_feature=['flip_left_right', 'flip_up_down', 'rotate', 'brightness'],
    augmentations_label=['flip_left_right', 'flip_up_down', 'rotate'],
#   possible values for normalize: 'to_meanstd', 'to_medianstd', 'to_perc'
    normalize='to_meanstd',
    n_folds=2,
    # Configuration dictionary handed to the ResUnetA constructor.
    model_config={
        'learning_rate': 0.0001,
        'n_layers': 3,
        'n_classes': n_classes,
        'keep_prob': 0.8,
        'features_root': 32,
        'conv_size': 3,
        'conv_stride': 1,
        'dilation_rate': [1, 3, 15, 31],
        'deconv_size': 2,
        'add_dropout': True,
        'add_batch_norm': False,
        'use_bias': False,
        'bias_init': 0.0,
        'padding': 'SAME',
        'pool_size': 3,
        'pool_stride': 2,
        'prediction_visualization': True,
        'class_weights': None
    }
)

In [6]:
### save model info
if not os.path.exists(model_folder):
    os.makedirs(model_folder)
original_stdout = sys.stdout # Save a reference to the original standard output

with open(model_folder +'/model_info.txt', 'w') as f:
    sys.stdout = f # Change the standard output to the file we created.
    pprint(vars(training_config))
    sys.stdout = original_stdout # Reset the standard output to its original value

In [7]:
# Log in to Weights & Biases only when a key has been configured.
# The Python API is used instead of the `!wandb login ...` shell magic so that the key is
# not interpolated into a shell command line and the cell remains plain Python.
if training_config.wandb_id is not None:
    wandb.login(key=training_config.wandb_id)

In [8]:
# Build one augmented, randomised dataset per fold from the local .npz files.
# Fold numbers passed to `get_dataset` are 1-based; list positions in `ds_folds` are 0-based.
fold_numbers = range(1, training_config.n_folds + 1)
ds_folds = [
    get_dataset(training_config,
                fold=fold_number,
                augment=True,
                randomize=True,
                num_parallel=200,
                npz_from_s3=False)
    for fold_number in fold_numbers
]

Check an example

In [9]:
# Batch the first fold's dataset so that a single example batch can be drawn from it.
ds_fold_ex = ds_folds[0].batch(training_config.batch_size)

In [10]:
# Pull the first batch out of the batched example dataset.
example_batch = next(iter(ds_fold_ex))

In [11]:
# First element of the batch holds the input features, second the reference labels.
feats, lbls = example_batch[0], example_batch[1]

In [12]:
# Show the shapes of the input features and of the three reference tensors.
feats['features'].shape, lbls['extent'].shape, lbls['boundary'].shape, lbls['distance'].shape 

(TensorShape([5, 256, 256, 4]),
 TensorShape([5, 256, 256, 2]),
 TensorShape([5, 256, 256, 2]),
 TensorShape([5, 256, 256, 2]))

In [None]:
# Show the first three samples of the example batch, one per row: a composite of feature
# channels 2, 1, 0 followed by channel 1 of the extent, boundary and distance references.
n_example_rows = 3
fig, axs = plt.subplots(nrows=n_example_rows, ncols=4,
                        sharex='all', sharey='all', figsize=(20, 15))

reference_keys = ('extent', 'boundary', 'distance')

for sample_idx in range(n_example_rows):
    row_axes = axs[sample_idx]
    row_axes[0].imshow(feats['features'].numpy()[sample_idx][..., [2, 1, 0]])
    for ref_ax, ref_key in zip(row_axes[1:], reference_keys):
        ref_ax.imshow(lbls[ref_key].numpy()[sample_idx][..., 1])

plt.tight_layout()

### Set up model & Train 


In [15]:
def initialise_model(config: TrainingConfig, chkpt_folder: str = None):
    """ Initialise and compile a ResUnetA model

    The model is built from `config.model_config` for inputs of shape `config.input_shape`
    and compiled with one Tanimoto distance loss per output (extent, boundary, distance),
    an Adam optimizer using `config.model_config['learning_rate']`, and accuracy,
    mean IoU and mean Dice metrics.

    If an existing checkpoints directory is provided, the existing weights are loaded and
    training starts from the existing state.

    :param config: Training configuration; `model_config` and `input_shape` are read from it
    :param chkpt_folder: Optional folder containing a `model.ckpt` checkpoint to load
    :return: The built ResUnetA model, with `model.net` compiled
    """
    # Everything is read from the `config` argument, so the function no longer depends on
    # the notebook-level `training_config` / `n_classes` globals.
    model = ResUnetA(config.model_config)

    model.build(dict(features=[None] + list(config.input_shape)))

    output_names = ('extent', 'boundary', 'distance')

    model.net.compile(
        loss={name: TanimotoDistanceLoss(from_logits=False) for name in output_names},
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=config.model_config['learning_rate']),
        # comment out the metrics you don't care about
        metrics=[segmentation_metrics['accuracy'](),
                 mean_iou, mean_dice]
    )

    if chkpt_folder is not None:
        model.net.load_weights(f'{chkpt_folder}/model.ckpt')

    return model


def initialise_callbacks(config: TrainingConfig,
                         fold: int) -> Tuple[str, List[Callable]]:
    """ Initialise callbacks used for logging and saving of models

    Creates a time-stamped output folder for the given fold, writes the model configuration
    to `model_cfg.json` inside it and sets up TensorBoard logging and checkpointing of the
    weights. When `config.wandb_id` is set, a wandb run is started and a `WandbCallback`
    is added as well.

    :param config: Training configuration (model folder/name, model config, wandb id, ...)
    :param fold: Number of the left-out fold, used in the folder and wandb run names
    :return: Tuple with the path of the created model folder and the list of callbacks
    """
    now = datetime.now().isoformat(sep='-', timespec='seconds').replace(':', '-')
    model_path = f'{config.model_folder}/{config.model_name}_fold-{fold}_{now}'
    os.makedirs(model_path, exist_ok=True)

    logs_path = os.path.join(model_path, 'logs')
    checkpoints_path = os.path.join(model_path, 'checkpoints', 'model.ckpt')

    # Tensorboard callback
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logs_path,
                                                          update_freq='epoch',
                                                          profile_batch=0)

    # Checkpoint saving callback
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoints_path,
                                                             save_best_only=True,
                                                             save_freq='epoch',
                                                             save_weights_only=True)

    callbacks = [tensorboard_callback, checkpoint_callback]

    # Save model config
    with open(f'{model_path}/model_cfg.json', 'w') as jfile:
        json.dump(config.model_config, jfile)

    # Initialise wandb if used. A single truthiness check guards both `wandb.init` and the
    # callback, so an empty id can no longer add a WandbCallback without an initialised run.
    if config.wandb_id:
        full_config = dict(**config.model_config,
                           iterations_per_epoch=config.iterations_per_epoch,
                           num_epochs=config.num_epochs,
                           batch_size=config.batch_size,
                           model_name=f'{config.model_name}_{now}')
        wandb.init(config=full_config,
                   name=f'{config.model_name}-leftoutfold-{fold}',
                   project="field-delineation",
                   sync_tensorboard=True)
        callbacks.append(WandbCallback())

    return model_path, callbacks

def _plot_history_curves(history, metric, epoch, ylabel, legend_labels, out_path):
    """ Plot per-output training (solid, markers) and validation (dashed) curves of `metric`

    Reads `<output>_<metric>` and `val_<output>_<metric>` from `history` for the extent,
    boundary and distance outputs, saves the figure to `out_path` and shows it.
    """
    output_colours = (('extent', '#7f7f7f'),
                      ('boundary', '#17becf'),
                      ('distance', '#9467bd'))

    fig = plt.figure(figsize=[10, 5])
    ax = fig.add_subplot(1, 1, 1)

    # Training curves first, then validation curves, so the legend labels line up with
    # the three training curves followed by the first validation curve.
    for name, colour in output_colours:
        ax.plot(epoch, history[f'{name}_{metric}'], linewidth=2, marker='o', c=colour)
    for name, colour in output_colours:
        ax.plot(epoch, history[f'val_{name}_{metric}'], linewidth=2, linestyle='dashed', c=colour)

    ax.legend(legend_labels, fontsize=14, shadow=True)
    ax.grid()
    ax.set_xlabel('Epoch', fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.savefig(out_path)
    plt.show()


def plot_epochs(training_config, h, testing_id):
    """ Save the training history and plot loss and accuracy curves per epoch

    The history dictionary is stored as `data_train_<fold>.npy` and the figures as
    `loss_<fold>.pdf` and `acc_<fold>.pdf` in `<model_folder>/<model_name>_plots`,
    where `<fold>` is `testing_id[0] + 1`.

    :param training_config: Configuration providing `model_folder` and `model_name`
    :param h: History dictionary as found in the `history` attribute of the object returned
              by `fit`; must hold `loss` and the per-output `*_loss` / `*_accuracy` entries
              together with their `val_` counterparts
    :param testing_id: Sequence whose first element is the 0-based index of the left-out fold
    """
    model_path = f'{training_config.model_folder}/{training_config.model_name}_plots'
    os.makedirs(model_path, exist_ok=True)

    fold_number = testing_id[0] + 1

    # The dictionary is pickled into the .npy file; reload it with
    # np.load(path, allow_pickle=True).item()
    np.save(model_path + f'/data_train_{fold_number}.npy', h)

    # Take the x axis from the recorded history rather than from the configured number of
    # epochs, so plotting still works when fewer epochs were actually recorded.
    epoch = np.arange(len(h['loss'])) + 1

    ### loss
    _plot_history_curves(h, 'loss', epoch,
                         ylabel='Loss function',
                         legend_labels=['Extent loss (training set)',
                                        'Boundary loss (training set)',
                                        'Distance loss (training set)',
                                        'Validation set'],
                         out_path=model_path + f'/loss_{fold_number}.pdf')

    ### acc
    _plot_history_curves(h, 'accuracy', epoch,
                         ylabel='Accuracy',
                         legend_labels=['Extent (training set)',
                                        'Boundary (training set)',
                                        'Distance (training set)',
                                        'Validation set'],
                         out_path=model_path + f'/acc_{fold_number}.pdf')



Indices defining which dataset folds to consider

In [None]:
folds = list(range(training_config.n_folds))

# Every (training folds, [left-out fold]) combination, one per fold.
fold_splits = [([other for other in folds if other != held_out], [held_out])
               for held_out in folds]

# Keep only the split that leaves out the fold at index 1.
folds_ids_list = [fold_splits[1]]

folds_ids_list

In [None]:
# Print the training fold indices and the left-out fold index of every selected split.
for training_ids, testing_id in folds_ids_list:
    print(training_ids, testing_id, sep='\n')

### Train models


In [None]:
# Seed both NumPy and TensorFlow: the model is trained with TensorFlow, whose random
# generator is not affected by the NumPy seed.
np.random.seed(training_config.seed)
tf.random.set_seed(training_config.seed)

# One entry per trained split, in the same order as `folds_ids_list`.
models = []
model_paths = []

for training_ids, testing_id in folds_ids_list:

    left_out_fold = testing_id[0]+1
    print(f'Training model for left-out fold {left_out_fold}')

    ### No k-cross validation
    # The left-out fold is used both for validation during training and for testing.
    fold_val = testing_id[0]
    folds_train = training_ids
    print(f'Train folds {folds_train}, Val fold: {fold_val}, Test fold: {testing_id[0]}')

    # Chain all training folds into one dataset. `concatenate` returns a new dataset
    # instead of modifying in place, hence the fold with `reduce`.
    ds_folds_train = [ds_folds[tid] for tid in folds_train]
    ds_train = reduce(tf.data.Dataset.concatenate, ds_folds_train)

    # Training data repeats indefinitely; `steps_per_epoch` below bounds each epoch.
    ds_train = ds_train.batch(training_config.batch_size).repeat()
    ds_val = ds_folds[fold_val].batch(training_config.batch_size)

    # Get model
    model = initialise_model(training_config, chkpt_folder=training_config.chkpt_folder)

    # Set up callbacks to monitor training
    model_path, callbacks = initialise_callbacks(training_config,
                                                 fold=left_out_fold)

    print(f'\tTraining model, writing to {model_path}')

    hist = model.net.fit(ds_train,
                         validation_data=ds_val,
                         epochs=training_config.num_epochs,
                         steps_per_epoch=training_config.iterations_per_epoch,
                         callbacks=callbacks, verbose=1)

    # Store the history and the loss/accuracy figures for this split.
    plot_epochs(training_config, hist.history, testing_id)

    models.append(model)
    model_paths.append(model_path)

    del fold_val, folds_train, ds_train, ds_val

### Check some validation results

In [None]:
# Draw one batch from the dataset of the fold at index 1.
ds_test_batched = ds_folds[1].batch(batch_size)
test_batch = next(iter(ds_test_batched))

In [None]:
# Run the most recently trained model on the features of the test batch.
batch_features = test_batch[0]['features'].numpy()
predictions = model.net.predict(batch_features)

In [None]:
n_images = 3

fig, axs = plt.subplots(nrows=n_images, ncols=5,
                        sharex='all', sharey='all',
                        figsize=(15, 3*n_images))

# One row per sample: composite of feature channels 2, 1, 0, channel 1 of the three
# model outputs (presumably extent, boundary, distance, in that order - TODO confirm),
# and channel 1 of the reference extent.
for nb in range(n_images):
    panels = [test_batch[0]['features'].numpy()[nb][..., [2, 1, 0]]]
    panels.extend(predictions[out_idx][nb][..., 1] for out_idx in range(3))
    panels.append(test_batch[1]['extent'].numpy()[nb][..., 1])

    for panel_ax, panel in zip(axs[nb], panels):
        panel_ax.imshow(panel)

plt.tight_layout()

## Evaluate models on test dataset

Once we are happy with the hyper-parameters, we can test the performance of the models on the left-out test dataset.

NOTE: bear in mind that this score is computed on augmented samples. For a better estimate of the score, recreate the datasets without augmentation.

In [None]:
# NOTE(review): `testing_id` is not assigned in this cell; it is the loop variable left over
# from the last iteration of an earlier loop over `folds_ids_list`.
testing_id[0]

In [None]:
# Evaluate every trained model on the fold that was left out while training it.
# `models` is filled in the same order as `folds_ids_list`, so the two are paired up here
# rather than reusing `model`, which only refers to the last model that was trained.
for fold_model, (_, testing_id) in zip(models, folds_ids_list):

    left_out_fold = testing_id[0]+1
    print(f'Evaluating model on left-out fold {left_out_fold}')

    fold_model.net.evaluate(ds_folds[testing_id[0]].batch(training_config.batch_size))

    print('\n\n')