In [1]:
# Train a U-Net segmentation model whose encoder is initialised from ImageNet-pretrained weights.
import os
# import cv2
import numpy as np
from glob import glob
from scipy.io import loadmat
import matplotlib.pyplot as plt
from time import time
import tensorflow as tf
from tensorflow import keras
import segmentation_models as sm

Segmentation Models: using `keras` framework.


In [2]:
# Report how many GPUs TensorFlow can see before any model is built.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

Num GPUs Available:  1


In [3]:
# Run configuration: input resolution, batch size, number of label classes,
# and the root folder that holds the training/validation splits.
IMAGE_SIZE = 1024
BATCH_SIZE = 4
NUM_CLASSES = 13
DATA_DIR = r"\\fatherserverdw\Q\research\images\skin_aging\deeplab_trainingset\v11_fold5"


def list_data_files(directory):
    """Return the sorted paths of the entries in `directory`, skipping `Thumbs.db`.

    Windows Explorer may or may not have created a `Thumbs.db` thumbnail cache
    in any given folder. Filtering by basename (case-insensitively) handles
    both situations, whereas `list.remove()` on a hard-coded path raises
    `ValueError` when the cache file is absent.

    Args:
        directory: Folder whose direct children are listed.

    Returns:
        A sorted list of path strings; empty if the folder has no matching
        entries or does not exist.
    """
    return sorted(
        path
        for path in glob(os.path.join(directory, "*"))
        if os.path.basename(path).lower() != "thumbs.db"
    )


# Images and masks are paired by position after sorting, so the same filter
# is applied to every folder to keep the lists aligned.
train_images = list_data_files(os.path.join(DATA_DIR, 'training', 'im'))
train_masks = list_data_files(os.path.join(DATA_DIR, 'training', 'label'))
val_images = list_data_files(os.path.join(DATA_DIR, 'validation', 'im'))
val_masks = list_data_files(os.path.join(DATA_DIR, 'validation', 'label'))


In [4]:
def read_image(image_path, mask=False):
    """Load a PNG file and resize it to IMAGE_SIZE x IMAGE_SIZE.

    Args:
        image_path: Path of the PNG file to read.
        mask: When True the file is decoded as a single-channel label map;
            otherwise it is decoded as a 3-channel image.

    Returns:
        A float32 tensor. For masks: shape (IMAGE_SIZE, IMAGE_SIZE, 1) holding
        the decoded label ids, with id 13 remapped to 0. For images: shape
        (IMAGE_SIZE, IMAGE_SIZE, 3) with the decoded values divided by 255.
    """
    raw = tf.io.read_file(image_path)
    if mask:
        image = tf.image.decode_png(raw, channels=1)
        image.set_shape([None, None, 1])
        # Labels are categorical: nearest-neighbour keeps every pixel a label id
        # that exists in the file. The default bilinear method would blend
        # neighbouring ids into fractional, meaningless values at class borders.
        image = tf.image.resize(
            images=image, size=[IMAGE_SIZE, IMAGE_SIZE], method="nearest"
        )
        # Fold label 13 into class 0.
        image = tf.where(tf.equal(image, 13), tf.zeros_like(image), image)
        # Nearest-neighbour resize keeps the decoded integer dtype; cast so the
        # mask is still delivered as float32, as it was before this change.
        image = tf.cast(image, tf.float32)
    else:
        image = tf.image.decode_png(raw, channels=3)
        image.set_shape([None, None, 3])
        image = tf.image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
        image = image / 255  # scale to [0, 1]
    return image

def load_data(image_list, mask_list):
    """Read one (image, mask) pair from their file paths.

    Args:
        image_list: Path of the input image file.
        mask_list: Path of the matching label mask file.

    Returns:
        Tuple `(image, mask)` as produced by `read_image`.
    """
    return read_image(image_list), read_image(mask_list, mask=True)

def data_generator(image_list, mask_list):
    """Build a batched `tf.data` pipeline from parallel lists of file paths.

    Args:
        image_list: Paths of the input images.
        mask_list: Paths of the label masks, in the same order as `image_list`.

    Returns:
        A dataset yielding `(images, masks)` batches of BATCH_SIZE elements;
        a trailing partial batch is dropped.
    """
    path_pairs = tf.data.Dataset.from_tensor_slices((image_list, mask_list))
    # Decode and resize the files in parallel.
    decoded = path_pairs.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    return decoded.batch(BATCH_SIZE, drop_remainder=True)

In [5]:
# Build the input pipelines for both splits and show their element specs.
train_dataset = data_generator(train_images, train_masks)
val_dataset = data_generator(val_images, val_masks)
for label, dataset in (("Train Dataset:", train_dataset), ("Val Dataset:", val_dataset)):
    print(label, dataset)

Train Dataset: <BatchDataset shapes: ((4, 1024, 1024, 3), (4, 1024, 1024, 1)), types: (tf.float32, tf.float32)>
Val Dataset: <BatchDataset shapes: ((4, 1024, 1024, 3), (4, 1024, 1024, 1)), types: (tf.float32, tf.float32)>


In [6]:
# U-Net with a ResNet-50 encoder initialised from ImageNet weights.
BACKBONE = 'resnet50'
sm.set_framework('tf.keras')
sm.framework()

# activation=None: no final activation is applied to the outputs; the loss in
# the training cell is configured with from_logits=True to match.
model = sm.Unet(
    BACKBONE,
    input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
    encoder_weights='imagenet',
    classes=NUM_CLASSES,
    activation=None,
)
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
data (InputLayer)               [(None, 1024, 1024,  0                                            
__________________________________________________________________________________________________
bn_data (BatchNormalization)    (None, 1024, 1024, 3 9           data[0][0]                       
__________________________________________________________________________________________________
zero_padding2d (ZeroPadding2D)  (None, 1030, 1030, 3 0           bn_data[0][0]                    
__________________________________________________________________________________________________
conv0 (Conv2D)                  (None, 512, 512, 64) 9408        zero_padding2d[0][0]             
____________________________________________________________________________________________

In [7]:
## Training

# The model outputs raw logits, so the loss is told to expect logits.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=loss,
    metrics=["accuracy"],
)

# Checkpointing: write weights only, and only when validation accuracy improves.
checkpoint_path = "fold_5_pretrained_v2_gpu/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

# Early stopping: end training once `val_loss` has not improved by at least
# 1e-2 for 10 epochs.
earlystop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=1e-2,
    patience=10,
    verbose=1,
)

def scheduler(epoch, lr):
    """Learning-rate schedule: hold for 10 epochs, then cut on even epochs.

    Args:
        epoch: Zero-based epoch index.
        lr: Learning rate going into this epoch.

    Returns:
        `lr` unchanged for epochs 0-9 and for every odd epoch afterwards;
        `lr * 0.1` for even epochs from 10 onwards.
    """
    # NOTE: when driven by Keras' LearningRateScheduler, `lr` is the optimizer's
    # current rate, so successive x0.1 steps compound.
    holding = epoch < 10
    if holding or epoch % 2:
        return lr
    return lr * 0.1

# Wrap the schedule in a callback that logs each learning-rate change.
lrschedule = tf.keras.callbacks.LearningRateScheduler(schedule=scheduler, verbose=1)
# Display the optimizer's learning rate before training as the cell output.
round(model.optimizer.lr.numpy(), 5)


1e-04

In [None]:
# Train for up to 40 epochs with checkpointing, early stopping and the LR
# schedule, and report the wall-clock duration.
start = time()
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=40,
    callbacks=[cp_callback, earlystop, lrschedule],
)
train_seconds = time() - start
print(np.around(train_seconds), 'seconds elapsed')
# Display the learning rate the optimizer ended on as the cell output.
round(model.optimizer.lr.numpy(), 5)

Epoch 1/40

Epoch 00001: LearningRateScheduler setting learning rate to 9.999999747378752e-05.


In [None]:
# Time one full inference pass over the validation set.
# This uses the `time` function imported in the first cell. Importing the
# `time` module here would rebind the name `time` to the module, so the
# training cell's `time()` calls would fail if that cell were run again.
start = time()
y = model.predict(val_dataset)
end = time()
print(end-start)

In [None]:
# Inspect the predictions: overall output shape first.
print(y.shape)
# Take the arg-max over the last axis of the second prediction to get one
# class id per pixel.
pred_mask = tf.math.argmax(y[1], axis=-1)
# List which class ids occur in that mask, then render it.
print(np.unique(pred_mask))
plt.imshow(pred_mask)

In [None]:
# Draw one figure per recorded metric. Each history key doubles as the y-axis
# label, so only the title needs to be listed separately.
for metric_key, plot_title in (
    ("loss", "Training Loss"),
    ("accuracy", "Training Accuracy"),
    ("val_loss", "Validation Loss"),
    ("val_accuracy", "Validation Accuracy"),
):
    plt.plot(history.history[metric_key])
    plt.title(plot_title)
    plt.ylabel(metric_key)
    plt.xlabel("epoch")
    plt.show()