In [33]:

from woundSegmentation.models.unets import Unet2D
from woundSegmentation.models.deeplab import Deeplabv3
from woundSegmentation.models.FCN import VGG_19
from woundSegmentation.models.SegNet import SegNet

from woundSegmentation.utils.learning.metrics import dice_coef, precision, recall
from woundSegmentation.utils.learning.losses import dice_coef_loss
from woundSegmentation.utils.io.data import DataGen, save_history

from tensorflow.keras.optimizers import Adam
from keras.callbacks import EarlyStopping

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

import os
import time

%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [None]:
# Model to train; must be one of "fcn", "mobilenetv2", "segnet", "unet".
modelName = "fcn"

# When True, train with tfmot magnitude pruning (results go to results_prune/).
prune = True

In [None]:
# Device selection: pruned training is forced onto the CPU.
if prune:
    try:
        # Hide every GPU from the TensorFlow runtime. Setting CUDA_VISIBLE_DEVICES
        # after `import tensorflow` only works if CUDA has not been initialised yet
        # in this process and is silently ignored otherwise; this call fails loudly
        # instead (it is a no-op when the configuration is already the same).
        tf.config.set_visible_devices([], 'GPU')
    except RuntimeError as err:
        raise RuntimeError(
            "Os dispositivos do TensorFlow ja foram inicializados; "
            "reinicie o kernel antes de executar com prune=True."
        ) from err

# Report the GPUs TensorFlow will actually use (not merely those physically present).
print("Numero de GPUs disponiveis: ", len(tf.config.get_visible_devices('GPU')))

In [None]:
# Models this notebook can build; each gets its own results sub-folder.
models = ["fcn", "mobilenetv2", "segnet", "unet"]

# Fail early: an unknown name would otherwise only surface later as a NameError on `model`.
if modelName not in models:
    raise ValueError(f"modelName deve ser um de {models}, recebido: {modelName!r}")

# Results root: pruned and unpruned runs are kept in separate trees.
if prune:
    dirPath = 'woundSegmentation/results_prune/'
else:
    dirPath = 'woundSegmentation/results/'

# Create the folder tree idempotently. makedirs(exist_ok=True) also creates missing
# parents and repairs a partially created tree; the previous "only when the root is
# missing" check left sub-folders absent if the root already existed.
for model_key in models:
    for subfolder in ("training_history", "datapredict"):
        os.makedirs(os.path.join(dirPath, model_key, subfolder), exist_ok=True)

modelDirPath = dirPath + modelName + "/"

In [None]:
# Data variables and generator
input_dim_x = 224
input_dim_y = 224
n_filters = 32
# Passed to DataGen as split_ratio; presumably the held-out (non-training) fraction — verify in DataGen.
split_ratio = 0.2
dataset = 'Foot Ulcer Segmentation Challenge'
# Alternative dataset: 'Medetec_foot_ulcer_224'
datasetpath = 'woundSegmentation/data/' + dataset
# Reuse datasetpath instead of rebuilding the same path string by hand.
data_gen = DataGen(datasetpath + '/', split_ratio=split_ratio, x=input_dim_x, y=input_dim_y)

In [None]:
# U-Net, built without fixed spatial dimensions (input_dim_x / input_dim_y = None).
if modelName == "unet":
    unet2d = Unet2D(
        n_filters=n_filters,
        num_channels=3,
        input_dim_x=None,
        input_dim_y=None,
    )
    model, model_name = unet2d.get_unet_model_yuanqing()
    print("Modelo Unet Carregado!")

In [None]:
# SegNet, built without fixed spatial dimensions (input_dim_x / input_dim_y = None).
if modelName == "segnet":
    segNet = SegNet(
        n_filters=n_filters,
        num_channels=3,
        input_dim_x=None,
        input_dim_y=None,
    )
    model, model_name = segNet.get_SegNet()
    print("Modelo Segnet Carregado!")

In [None]:
# MobilenetV2 (via the project's Deeplabv3 builder), single output class, 3-channel input.
if modelName == "mobilenetv2":
    model = Deeplabv3(
        input_shape=(input_dim_x, input_dim_y, 3),
        classes=1,
    )
    model_name = 'MobilenetV2'
    print("Modelo MobilenetV2 Carregado!")

In [None]:
# FCN (via the project's VGG_19 builder), 3-channel input of input_dim_x x input_dim_y.
if modelName == "fcn":
    model = VGG_19(
        input_shape=(input_dim_x, input_dim_y, 3),
    )
    model_name = 'FCN'
    print("Modelo FCN Carregado!")

In [None]:
# Training hyper-parameters
batch_size, epochs = 2, 240
learning_rate = 0.0001
# Keras loss identifier handed to compile()
loss = 'binary_crossentropy'

In [None]:
# Pruning technique (tfmot magnitude pruning with a polynomial sparsity schedule)
if prune:
    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

    # The schedule should finish on the last optimizer step of the run. Derive the
    # step count from the generator with the same formula used for steps_per_epoch
    # in the training cell, instead of a hard-coded dataset size (810) and split
    # ratio (0.2) that silently go stale when the dataset changes.
    train_steps_per_epoch = int(data_gen.get_num_data_points(train=True) / batch_size)
    end_step = train_steps_per_epoch * epochs

    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.30, final_sparsity=0.70, begin_step=0, end_step=end_step)
    }

    model_pruned = prune_low_magnitude(model, **pruning_params)
else:
    # Unpruned run: later cells use `model_pruned`, so alias it to the plain model.
    # This also prevents a stale pruned model from a previous prune=True run being reused.
    model_pruned = model

In [None]:
# Write the model summary to a text file
def printmodelsummary(s):
    """Append one line of a Keras model summary to <modelDirPath><modelName>Modelsummary.txt.

    Intended as the ``print_fn`` of ``Model.summary``, which calls it once per summary line.
    """
    with open(modelDirPath + modelName + 'Modelsummary.txt', 'a') as f:
        print(s, file=f)

# printmodelsummary appends, so start from an empty file; otherwise each re-run of
# this cell stacks another copy of the summary below the previous one.
with open(modelDirPath + modelName + 'Modelsummary.txt', 'w'):
    pass

# model_pruned only exists when pruning is enabled.
(model_pruned if prune else model).summary(print_fn=printmodelsummary)

In [None]:
# Callbacks: early stopping on validation Dice (restoring the best weights), plus the
# pruning-step updater when pruning. EarlyStopping is taken from tf.keras, the same
# Keras the optimizer (tensorflow.keras Adam) and tfmot use, so that callback and model
# come from the same Keras implementation.
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_dice_coef', patience=50, mode='max', restore_best_weights=True)
]
if prune:
    # tfmot requires this callback during training to advance the pruning step.
    callbacks.append(tfmot.sparsity.keras.UpdatePruningStep())

In [None]:
# Training
# model_pruned only exists when pruning is enabled.
training_model = model_pruned if prune else model

start = time.time()
# `learning_rate` replaces the deprecated `lr` keyword.
training_model.compile(optimizer=Adam(learning_rate=learning_rate), loss=loss, metrics=[dice_coef, precision, recall])
# Model.fit accepts Python generators directly; fit_generator is deprecated.
training_history = training_model.fit(
    data_gen.generate_data(batch_size=batch_size, train=True),
    steps_per_epoch=int(data_gen.get_num_data_points(train=True) / batch_size),
    callbacks=callbacks,
    validation_data=data_gen.generate_data(batch_size=batch_size, val=True),
    validation_steps=int(data_gen.get_num_data_points(val=True) / batch_size),
    epochs=epochs,
)
end = time.time()

# Wall-clock training time in seconds (compile included); `with` closes the file even on error.
with open(modelDirPath + modelName + "Time.txt", "w") as time_file:
    time_file.write("Treinamento\n")
    time_file.write(str(end - start))

In [None]:
# Save the model and its training history.
# NOTE(review): this saves `model`, not the pruning-wrapped `model_pruned`; tfmot's
# prune_low_magnitude wraps the original layers, so `model` presumably holds the trained
# (pruned) weights — confirm, or save tfmot.sparsity.keras.strip_pruning(model_pruned).
# The file name only says "pruned" when pruning was actually applied.
save_history(model, training_history, model_name, dataset, n_filters, epochs, learning_rate, loss, color_space='RGB',
             path=modelDirPath + "training_history/",
             name=modelName + ("prunedmodelfile" if prune else "modelfile"))