**ResNet50** for classification of PCam images

In [28]:
import tensorflow as tf 
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

from tensorflow.keras.applications.resnet50 import ResNet50

In [48]:
def freeze_model_weights(model):
    """Mark every layer of ``model`` as non-trainable, in place.

    Iterates over ``model.layers`` and sets each layer's ``trainable``
    attribute to ``False``. The model is mutated; nothing is returned.
    """
    for layer in model.layers:
        setattr(layer, 'trainable', False)


In [None]:
# Load the last 2% of each PCam split from the local data directory
# (download is disabled, so the dataset must already be present there).
# The results follow the order of `split`: d1 = train, d2 = test, d3 = validation.
d1, d2, d3 = tfds.load(
    'patch_camelyon',
    split=['train[98%:]', 'test[98%:]', 'validation[98%:]'],
    data_dir='./Data/PCAM',
    download=False,
    shuffle_files=True,
)

In [43]:
def convert_sample(sample):
    """Turn one dataset record into an ``(image, label)`` training pair.

    Args:
        sample: mapping with an ``'image'`` entry and a ``'label'`` entry,
            as yielded by the loaded dataset.
            # assumes 'image' holds 0-255 RGB pixel values and 'label' is an
            # integer class id in {0, 1} — TODO confirm against the dataset

    Returns:
        Tuple ``(image, label)`` where ``image`` is a float32 tensor
        preprocessed for ResNet50 and ``label`` is a float32 one-hot vector
        of length 2.
    """
    image, label = sample['image'], sample['label']
    # The ImageNet ResNet50 weights expect the model's own preprocessing
    # (0-255 floats, RGB->BGR, per-channel mean subtraction), not [0, 1]
    # scaling; feeding [0, 1] images mismatches the pretrained features.
    image = tf.cast(image, tf.float32)
    image = tf.keras.applications.resnet50.preprocess_input(image)
    label = tf.one_hot(label, 2, dtype=tf.float32)
    return image, label

# Build the batched input pipelines (d1 = train, d3 = validation, d2 = test).
# Only the training stream is shuffled, so batches differ from epoch to epoch;
# validation/test keep their order. Parallel map + prefetch overlap the
# preprocessing with model execution without changing the produced values.
train_data = (
    d1.shuffle(1024)
    .map(convert_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
)
validation_data = (
    d3.map(convert_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
)
test_data = (
    d2.map(convert_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
)

In [73]:
# ImageNet-pretrained ResNet50 without its original classification head,
# configured for 96x96 three-channel input images.
pretrained_ResNet50 = ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=(96, 96, 3),
)

In [80]:
# Freeze all weights of the pretrained backbone (sets trainable=False on each
# of its layers) so that training does not update them.
freeze_model_weights(pretrained_ResNet50)

In [81]:
# Print the layer-by-layer summary of the pretrained backbone.
pretrained_ResNet50.summary()

Model: "resnet50"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_6 (InputLayer)        [(None, 96, 96, 3)]          0         []                            
                                                                                                  
 conv1_pad (ZeroPadding2D)   (None, 102, 102, 3)          0         ['input_6[0][0]']             
                                                                                                  
 conv1_conv (Conv2D)         (None, 48, 48, 64)           9472      ['conv1_pad[0][0]']           
                                                                                                  
 conv1_bn (BatchNormalizati  (None, 48, 48, 64)           256       ['conv1_conv[0][0]']          
 on)                                                                                       

In [82]:
# Build the PCam classifier: the pre-trained ResNet50 backbone followed by a
# small dense head for binary classification.
PCam_ResNet50 = tf.keras.models.Sequential([
    pretrained_ResNet50,
    # Flatten the backbone's feature map into a vector for the dense head.
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(250, activation='relu'),
    # Two softmax outputs, one per class.
    tf.keras.layers.Dense(2, activation='softmax'),
])

In [83]:
# Print the summary of the full classifier (backbone plus dense head),
# including the trainable / non-trainable parameter counts.
PCam_ResNet50.summary()

Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50 (Functional)       (None, 3, 3, 2048)        23587712  
                                                                 
 flatten_7 (Flatten)         (None, 18432)             0         
                                                                 
 dense_10 (Dense)            (None, 250)               4608250   
                                                                 
 dense_11 (Dense)            (None, 2)                 502       
                                                                 
Total params: 28196464 (107.56 MB)
Trainable params: 4608752 (17.58 MB)
Non-trainable params: 23587712 (89.98 MB)
_________________________________________________________________


In [84]:
# Configure training: Adam optimizer, categorical cross-entropy loss on the
# two-way softmax output, and accuracy as the reported metric.
PCam_ResNet50.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [87]:
# Stop training once the validation loss has gone 5 epochs without improving,
# and restore the weights from the best epoch seen.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

# Train for at most 50 epochs, validating after each one; `hist` keeps the
# object returned by `fit` for later inspection.
hist = PCam_ResNet50.fit(
    train_data,
    epochs=50,
    validation_data=validation_data,
    callbacks=[early_stopping_callback],
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
