In [1]:
import tensorflow as tf
import copy
import warnings

from loss import fun_simclr_loss
from training_function import fun_train_simclr

In [18]:
#@title Hyper-parameters
# Naming used below: the `fsp` / `prx` / `dwm` infix selects the model, and the
# suffix selects the training stage (`_trf` is used by the transfer-learning
# cell, `_fnt` by the fine-tuning cell).

# Number of training examples that keep their labels
labeled_input_num = 1000

# Learning rates, given as (transfer-learning, fine-tuning) pairs
learning_rate_fsp_trf, learning_rate_fsp_fnt = 1e-2, 1e-4
learning_rate_prx_trf, learning_rate_prx_fnt = 1e-2, 1e-6
learning_rate_dwm_trf, learning_rate_dwm_fnt = 1e-2, 1e-4

# Batch sizes, given as (transfer-learning, fine-tuning) pairs
batch_fsp_trf, batch_fsp_fnt = 128, 128
batch_prx_trf, batch_prx_fnt = 64, 64
batch_dwm_trf, batch_dwm_fnt = 128, 128

In [19]:
# Load CIFAR-10 as (images, integer labels) pairs for the train and test splits.
(train_data, train_labels), (test_data, test_labels) = tf.keras.datasets.cifar10.load_data()

# Rescale the uint8 pixel values to the range [0, 1].
# Cast to float32 first: dividing the uint8 arrays directly promotes them to
# float64, which takes twice the memory of float32 (and the full training
# array is copied again further down in this notebook).
train_data = train_data.astype('float32') / 255.0
test_data = test_data.astype('float32') / 255.0

# One-hot encode the integer class labels.
train_labels = tf.keras.utils.to_categorical(train_labels)
test_labels = tf.keras.utils.to_categorical(test_labels)

print('Shape of train_data: {}'.format(train_data.shape))
print('Shape of test_data: {}'.format(test_data.shape))
print('Shape of train_labels: {}'.format(train_labels.shape))
print('Shape of test_labels: {}'.format(test_labels.shape))

Shape of train_data: (50000, 32, 32, 3)
Shape of test_data: (10000, 32, 32, 3)
Shape of train_labels: (50000, 10)
Shape of test_labels: (10000, 10)


In [20]:
# Two random augmentation pipelines; both are handed to fun_train_simclr below.

# Pipeline 1: random 20x20 crop, then resize back to the dataset's image size.
fun_augment_a = tf.keras.layers.RandomCrop(20, 20)
fun_augment_b = tf.keras.layers.Resizing(
    height=train_data.shape[1],
    width=train_data.shape[2],
)
fun_augment_01 = tf.keras.Sequential(layers=[fun_augment_a, fun_augment_b])

# Pipeline 2: random rotation with factor 0.2.
fun_augment_02 = tf.keras.layers.RandomRotation(0.2)

In [21]:
# Randomly select `labeled_input_num` DISTINCT training examples to act as the
# labeled subset. Shuffling the full index range and keeping the first k
# entries samples without replacement; drawing k independent random integers
# (as randint does) can repeat an index, which would leave fewer than k unique
# labeled examples.
# NOTE(review): the draw is unseeded, so the subset differs between runs; call
# tf.random.set_seed(...) beforehand if a reproducible subset is required.
index_tr = tf.random.shuffle(tf.range(train_data.shape[0]))[:labeled_input_num].numpy()

train_data_labeled = train_data[index_tr,:,:,:]
train_labels_labeled = train_labels[index_tr,:]

# Independent copies of the labeled subset for the fsp model
train_data_fsp = copy.deepcopy(train_data_labeled)
train_labels_fsp = copy.deepcopy(train_labels_labeled)

# The prx model is trained on all training images (no labels are copied for it)
train_data_prx = copy.deepcopy(train_data)

# Independent copies of the labeled subset for the dwm model
train_data_dwm = copy.deepcopy(train_data_labeled)
train_labels_dwm = copy.deepcopy(train_labels_labeled)

# There are 50,000 training inputs; 1000 (labeled_input_num = 1000) of them are labeled

In [12]:
# Build the pretrained DenseNet121 backbone that the fsp and prx base models
# are cloned from.

# The network input has the dataset's own image shape (height, width, channels).
input_layer = tf.keras.Input(shape=train_data.shape[1:])

# Resize (with padding) to 160x160 inside the graph before the backbone.
upscale = tf.keras.layers.Lambda(
    lambda x: tf.image.resize_with_pad(
        x,
        target_height=160,
        target_width=160,
        method=tf.image.ResizeMethod.BILINEAR,
    )
)(input_layer)

# ImageNet weights, no classification head, global max pooling on the output.
model_DenseNet121 = tf.keras.applications.DenseNet121(
    weights="imagenet",
    include_top=False,
    pooling='max',
    input_tensor=upscale,
    input_shape=(160, 160, 3),
)

In [13]:
# Give each model its own copy of the backbone: clone the architecture, then
# copy the weights of model_DenseNet121 into every clone.
model_base_fsp = tf.keras.models.clone_model(model_DenseNet121)
model_base_prx = tf.keras.models.clone_model(model_DenseNet121)  # encoder

for cloned_base in (model_base_fsp, model_base_prx):
    cloned_base.set_weights(model_DenseNet121.get_weights())

# One batch-normalization layer per model, placed after its backbone later on.
batch_normalization_fsp = tf.keras.layers.BatchNormalization()
batch_normalization_prx = tf.keras.layers.BatchNormalization()

In [14]:
# SimCLR projector: a small MLP that maps the 1024-dimensional encoder output
# to a 128-dimensional embedding (1024 -> 512 -> 128, ReLU after each layer).
#
# `shape` must be a tuple. `(1024)` is just the integer 1024 in parentheses;
# `(1024,)` is the one-element tuple that tf.keras.Input documents, and newer
# Keras versions reject a non-iterable shape.
layers_dense_prx = [tf.keras.Input(shape=(1024,)),
                    tf.keras.layers.Dense(512, activation = 'relu'),
                    tf.keras.layers.Dense(128, activation = 'relu')]

model_projector = tf.keras.Sequential(layers_dense_prx)

In [15]:
# Assemble the two models from their parts.

# model_fsp: backbone -> batch norm -> softmax layer with one unit per class.
layerou_fsp = tf.keras.layers.Dense(units=train_labels_fsp.shape[-1],
                                    activation='softmax')
model_fsp = tf.keras.Sequential(
    layers=[model_base_fsp, batch_normalization_fsp, layerou_fsp]
)

# model_prx: backbone (encoder) -> batch norm -> SimCLR projector.
model_prx = tf.keras.Sequential(
    layers=[model_base_prx, batch_normalization_prx, model_projector]
)

In [16]:
# Train the prx model using transfer learning and fine-tuning
# Stage 1 - transfer learning: freeze the encoder and its batch-norm layer so
# that only the projector's weights are trainable.

model_base_prx.trainable = False
batch_normalization_prx.trainable = False

# `metrics` is passed as a list, as Model.compile documents; a bare string is
# rejected by newer Keras versions.
# NOTE(review): 'accuracy' is probably not meaningful for a 128-d projection
# trained with a contrastive loss - confirm what fun_train_simclr reports.
model_prx.compile(optimizer = tf.keras.optimizers.Adam(learning_rate_prx_trf),
                  loss = fun_simclr_loss,
                  metrics = ['accuracy'])

model_prx.summary()

Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 densenet121 (Functional)    (None, 1024)              7037504   
                                                                 
 batch_normalization_3 (Batc  (None, 1024)             4096      
 hNormalization)                                                 
                                                                 
 sequential_6 (Sequential)   (None, 128)               590464    
                                                                 
Total params: 7,632,064
Trainable params: 590,464
Non-trainable params: 7,041,600
_________________________________________________________________


In [17]:
# Transfer-learning run of model_prx on the full training images, using the
# two augmentation pipelines and the transfer-learning batch size.
simclr_transfer_options = dict(
    epochs=5,
    batch_size=batch_prx_trf,
    verbose=1,
    patience=1,
)

model_prx, _ = fun_train_simclr(
    model_prx,
    train_data_prx,
    fun_augment_01,
    fun_augment_02,
    **simclr_transfer_options,
)

KeyboardInterrupt: 

In [11]:

# Fine-tuning
# Stage 2: unfreeze the encoder and its batch-norm layer, then continue
# training the whole prx model with the (much smaller) fine-tuning learning rate.
# This cell needs `model_prx` from the cells above, so run them first in a
# fresh kernel.

model_base_prx.trainable = True
batch_normalization_prx.trainable = True

# Re-compile so the changed `trainable` flags take effect. `metrics` is passed
# as a list, as Model.compile documents; a bare string is rejected by newer
# Keras versions.
model_prx.compile(optimizer = tf.keras.optimizers.Adam(learning_rate_prx_fnt),
                  loss = fun_simclr_loss,
                  metrics = ['mean_squared_error'])

model_prx, _ = fun_train_simclr(model_prx,
                                train_data_prx,
                                fun_augment_01,
                                fun_augment_02,
                                epochs = 1,
                                batch_size = batch_prx_fnt,
                                verbose = 1,
                                patience = 1)

NameError: name 'model_prx' is not defined

In [22]:
# Could also use the Animal-10 Kaggle dataset (taking 4 classes instead of 10), since the final model will have 4 classes of scalp images
# Try ResNet / VGG16 instead of DenseNet121 (fewer layers)
# Increase the number of epochs