In [None]:
import tensorflow as tf
import os
os.environ['SM_FRAMEWORK'] = 'tf.keras'
import numpy as np
import pandas as pd
import random
from tqdm.notebook import tqdm

from sklearn.model_selection import train_test_split
from segmentation_models.losses import dice_loss
from segmentation_models.metrics import iou_score

from GlottisNetV2.Utils.DataGenerator import DataGenerator
from GlottisNetV2.Utils.data import load_data, metric_mape, mape_ap, mape_pp
from GlottisNetV2.Utils.Callbacks import get_callbacks

from GlottisNetV2.Models.GlottisNetV2_a import glottisnetV2_a
from GlottisNetV2.Models.GlottisNetV2_b import glottisnetV2_b
from GlottisNetV2.Models.GlottisNetV2_c import glottisnetV2_c
from GlottisNetV2.Models.GlottisNetV2_d import glottisnetV2_d
from GlottisNetV2.Models.GlottisNetV2_e import glottisnetV2_e

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# Set path to the JSON file with coordinates of the anterior and posterior points
coord_train = r"Set path to JSON-file with AP points"  # TODO

# Set path to training data (directory containing <i>.png images and <i>_seg.png masks)
img_training = r"Set path to training images"  # TODO

N_train = 100  # number of training images (full dataset: 55750)

# Column layout shared by the image and segmentation frames.
cols = ['z', 'path']


def _build_path_frame(directory, n_images, suffix, columns=cols):
    """Return a DataFrame with one row per image index.

    Args:
        directory: Folder that holds the files.
        n_images: Number of rows; indices run from 0 to n_images - 1.
        suffix: File-name suffix appended to the index, e.g. ".png" or "_seg.png".
        columns: Column names, (index column, path column).

    Returns:
        DataFrame with an integer index column and the file path
        ``os.path.join(directory, f"{index}{suffix}")``, on a fresh 0..n-1 RangeIndex.
    """
    records = [(i, os.path.join(directory, f"{i}{suffix}")) for i in range(n_images)]
    # Build the frame in one go. Growing it with pd.concat in a loop is quadratic
    # and leaves every row with the index label 0.
    return pd.DataFrame(records, columns=list(columns))


# Create image IDs (index + file path) for the training data.
df_imgs_train = _build_path_frame(img_training, N_train, ".png")
df_segs_train = _build_path_frame(img_training, N_train, "_seg.png")

print('Created IDs for training images.')

# Save coordinates of anterior and posterior points in Pandas Dataframe
training_data = load_data(coord_train, N_train)
print('Loaded anterior and posterior points to dataframe.')

In [None]:
# --- Training ---
# Set random seeds for reproducible training.
SEED = 42
# NOTE: changing PYTHONHASHSEED at runtime only affects subprocesses started from this
# kernel. Hash randomisation of the running interpreter is fixed at startup.
os.environ['PYTHONHASHSEED'] = str(SEED)
np.random.seed(SEED)      # legacy global NumPy RNG (returns None, so nothing to keep)
random.seed(SEED)         # Python's `random` module
tf.random.set_seed(SEED)  # TF2 global seed; TF2 name for tf.compat.v1.set_random_seed
# NOTE(review): bit-exact GPU reproducibility may additionally need
# tf.config.experimental.enable_op_determinism() (TF >= 2.8). Confirm if required.

# ---- Optimisation ----
BATCH_SIZE = 8
EPOCHS = 30
LEARNING_RATE = 2e-4
SHUFFLE = True   # shuffle the training generator
AUGMENT = True   # augment the training generator (validation is never augmented)

# ---- Architecture (passed to the model constructor) ----
LAYERS = 4
FILTERS = 16

# ---- Input size and generator settings ----
TARGET_HEIGHT = 512
TARGET_WIDTH = 256
RADIUS = 15      # forwarded to DataGenerator as `radius`

# ---- Persistence ----
MODEL_PATH = r"Set model path"  # TODO
STEPS_PATH = r"Set path to model checkpoints"  # TODO
N_STEPS = 20     # save a checkpoint every N_STEPS epochs

# Build the network for single-channel input of size TARGET_HEIGHT x TARGET_WIDTH.
model = glottisnetV2_e(input_size=(TARGET_HEIGHT, TARGET_WIDTH, 1),
                       layers=LAYERS,
                       filters=FILTERS)

# Fixed 90/10 train/validation split. Images and segmentations go through a single
# call so that the two frames stay row-aligned.
train_imgs, val_imgs, train_segs, val_segs = train_test_split(
    df_imgs_train, df_segs_train, test_size=0.1, random_state=SEED)

# Settings common to both generators, so training and validation cannot drift apart.
generator_kwargs = dict(batch_size=BATCH_SIZE,
                        target_height=TARGET_HEIGHT,
                        target_width=TARGET_WIDTH,
                        df_coordinates=training_data,
                        radius=RADIUS)

# Training data: shuffle/augment according to the config flags.
training_generator = DataGenerator(train_imgs, train_segs,
                                   shuffle=SHUFFLE, augment=AUGMENT,
                                   **generator_kwargs)

# Validation data: fixed order, no augmentation.
validation_generator = DataGenerator(val_imgs, val_segs,
                                     shuffle=False, augment=False,
                                     **generator_kwargs)

# Two named output heads:
#   'ap_pred' - anterior/posterior point prediction maps (2 channels), trained with MSE
#   'seg'     - glottis segmentation, trained with Dice loss
# Optimiser: Adam.
# NOTE(review): 'acc' on the MSE-trained 'ap_pred' maps is probably not a meaningful
# metric. Confirm before relying on it.
# NOTE(review): run_eagerly=True slows training considerably; presumably needed by the
# custom metrics. TODO confirm.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
              loss={'ap_pred': 'mse', 'seg': dice_loss},
              metrics={'seg': ['acc', iou_score],
                       'ap_pred': ['acc', metric_mape, mape_ap, mape_pp]},
              run_eagerly=True)

# Train; saving/checkpointing is presumably done by the callbacks from get_callbacks.
model.fit(training_generator,
          validation_data=validation_generator,
          epochs=EPOCHS,
          callbacks=get_callbacks(MODEL_PATH, model, N_STEPS, STEPS_PATH))

