In [None]:
import numpy as np
import tensorflow
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.utils import normalize, to_categorical, set_random_seed as keras_set_random_seed
from tqdm import tqdm

# NOTE(review): relative imports only resolve when this code runs as part of a
# package; in a plain notebook kernel they raise ImportError — confirm how this is executed.
from .model import PRC_Net
# NOTE(review): star import hides which loader names come from here
# (e.g. load_train_image_arrays_selective_augs); prefer explicit names.
from .Utils.load_train_data import *
from .Utils.loss import ccfl_dice


In [None]:
# Experiment configuration: every value a reader might tune lives in this cell.
channel = 'RGBNIRRE'      # NOTE(review): not used in the visible cells; presumably an input-band tag — confirm
batch_size = 6            # mini-batch size for model.fit
image_size = 320          # forwarded to the image/mask loaders
epoch_nums = 142          # maximum number of training epochs (EarlyStopping may stop sooner)
patience = 63             # EarlyStopping patience (epochs without val_loss improvement)
n_classes = 3             # number of segmentation classes = one-hot depth of the masks
aug_degree = 2            # NOTE(review): not used in the visible cells — confirm its purpose
augs = [1, 2]             # augmentation ids forwarded to the *_selective_augs loaders
# Derived from `augs` so the string tag can never drift out of sync with the list.
aug_s = ''.join(str(aug_id) for aug_id in augs)
l2_reg = 0.0              # L2 regularisation factor passed to PRC_Net
patience_lr = 30          # ReduceLROnPlateau patience (epochs without val_loss improvement)
factor_lr = 0.5           # ReduceLROnPlateau multiplicative learning-rate decay
seed = 0                  # global random seed for a reproducible run

# Seeds Python's `random`, NumPy and TensorFlow in one call. The function was
# already imported for this purpose but never called, so runs were not reproducible.
keras_set_random_seed(seed)
factor_lr = 0.5

In [None]:
# Load the training images and their masks through the project loaders,
# passing the configured `image_size` and the `augs` selection to both.
# TODO: replace the placeholder directories with real paths.
train_images = load_train_image_arrays_selective_augs(
    dir_path="PATH_TO_IMG_DIR",
    image_size=image_size,
    augs=augs,
)
train_masks = load_train_masks_selective_augs(
    dir_path='PATH_TO_MASK_DIR',
    image_size=image_size,
    augs=augs,
)

In [None]:
# Map raw mask values to contiguous class ids 0..k-1, L2-normalise the images,
# split train/validation, and one-hot encode the masks for training.
labelencoder = LabelEncoder()
n, h, w = train_masks.shape
train_masks_reshaped = train_masks.reshape(-1, 1)
# LabelEncoder expects a 1-D target; passing the (N, 1) column vector raised a
# DataConversionWarning, so hand it the flattened array instead.
train_masks_reshaped_encoded = labelencoder.fit_transform(train_masks_reshaped.ravel())
train_masks_encoded_original_shape = train_masks_reshaped_encoded.reshape(n, h, w)

# Replaces the discarded `np.unique(...)` check: fail fast with a readable message
# instead of an IndexError inside to_categorical when the masks hold too many labels.
found_classes = len(labelencoder.classes_)
assert found_classes <= n_classes, (
    f"masks contain {found_classes} distinct values {labelencoder.classes_}, "
    f"but n_classes={n_classes}"
)

# Keras `normalize` divides by the L2 norm along axis=-1 (assumed to be the
# channel axis — confirm against the loader's output layout).
train_images = normalize(train_images, axis=-1)
train_masks_input = np.expand_dims(train_masks_encoded_original_shape, axis=3)

# NOTE(review): if the loaders return augmented copies next to their source image,
# a random split can place both in train and val and inflate validation metrics.
X_train, X_val, y_train, y_val = train_test_split(
    train_images, train_masks_input, test_size=0.2, random_state=0
)

# One-hot encode to (N, H, W, n_classes); the reshape pins that expected shape explicitly.
train_masks_cat = to_categorical(y_train, num_classes=n_classes)
y_train_cat = train_masks_cat.reshape(y_train.shape[:3] + (n_classes,))
val_masks_cat = to_categorical(y_val, num_classes=n_classes)
y_val_cat = val_masks_cat.reshape(y_val.shape[:3] + (n_classes,))

In [None]:
# Build PRC_Net sized from the training tensor: axes 1, 2, 3 of X_train are passed
# as height, width and channels. Then compile it with the project loss and
# report accuracy plus AUC during training.
model = PRC_Net(
    n_classes=n_classes,
    IMG_HEIGHT=X_train.shape[1],
    IMG_WIDTH=X_train.shape[2],
    IMG_CHANNELS=X_train.shape[3],
    dropout_rate=0.0,
    l2_reg=l2_reg,
)
model.compile(
    optimizer='adam',
    loss=ccfl_dice,
    metrics=['accuracy', tensorflow.keras.metrics.AUC(name='auc')],
)
model.summary()

In [None]:

# Train with early stopping, best-model checkpointing and LR decay, all keyed on val_loss.
# `filepath` was never defined anywhere above, so this cell raised NameError on a
# fresh kernel. NOTE(review): the naming scheme is a placeholder — confirm the location.
# The `.keras` suffix is what Keras 3's ModelCheckpoint requires for full-model saves.
filepath = f"PRC_Net_{channel}_aug{aug_s}_img{image_size}_best.keras"

callbacks = [
    # restore_best_weights keeps the in-memory model consistent with the saved best checkpoint.
    EarlyStopping(patience=patience, monitor='val_loss', restore_best_weights=True),
    ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min'),
    ReduceLROnPlateau(monitor='val_loss', factor=factor_lr, patience=patience_lr, min_lr=1e-6),
]

history = model.fit(
    X_train, y_train_cat,
    batch_size=batch_size,
    verbose=1,
    epochs=epoch_nums,
    validation_data=(X_val, y_val_cat),
    # Reshuffle every epoch; shuffle=False fed identical mini-batches in the same order each epoch.
    shuffle=True,
    callbacks=callbacks,
)