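# Transfer Learning with MobileNetV3

Multi-label classification of the 133 COCO 2017 panoptic categories with a pretrained MobileNetV3-Large backbone and a small sigmoid head.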

<a name='1'></a>
# Packages

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow.keras.layers as tfl
import tensorflow_addons as tfa
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from keras_flops import get_flops

from train_utils.callbacks import LossHistory, LRCallBack
from train_utils.metrics import TopKAccuracy
from train_utils.utils import is_in
from train_utils.losses import binary_focal_crossentropy_loss
from train_utils.lite_accuracy import evaluate_model
from train_utils.metadata_writer_for_image_classifier import generate_metadata
import tensorflow_model_optimization as tfmot

# Keras quantization helper from tensorflow_model_optimization, kept under a short
# alias (it is not referenced by the other cells shown here).
quantize_model = tfmot.quantization.keras.quantize_model

# Sanity check: report how many GPUs TensorFlow can see before the heavy cells run.
print("Num GPUs Available: ", len(tf.config.list_physical_devices(device_type='GPU')))

Num GPUs Available:  1


# Datasets

In [2]:
# Input-pipeline and model geometry shared by every later cell.
BATCH_SIZE = 32              # examples per batch (batching is done by tfds.load)
PREFETCH = 2                 # prefetch depth, also used as map() parallelism
IMG_SIZE = (224, 224)        # (height, width) the images are resized to
IMG_SHAPE = (*IMG_SIZE, 3)   # network input shape: IMG_SIZE plus 3 colour channels
NUM_CLASSES = 133            # matches ClassLabel(num_classes=133) of panoptic_objects/label

In [3]:
# Root directory of the tfds cache. Set the TFDS_DATA_DIR environment variable to
# override it instead of editing the notebook; the previous Windows path stays the
# default for backward compatibility.
TFDS_DATA_DIR = os.environ.get('TFDS_DATA_DIR', r'C:\tensorflow_datasets')


def load_coco_panoptic(split, shuffle_files):
    """Load one split of COCO 2017 panoptic from tfds, already batched.

    Args:
        split: tfds split name, e.g. 'train' or 'validation'.
        shuffle_files: whether tfds should shuffle the order of the source files.

    Returns:
        ``(dataset, info)`` as returned by ``tfds.load(..., with_info=True)``.
        Each dataset element is a batch of BATCH_SIZE examples.
    """
    return tfds.load(
        'coco/2017_panoptic',
        split=split,
        batch_size=BATCH_SIZE,
        data_dir=TFDS_DATA_DIR,
        download=True,
        shuffle_files=shuffle_files,
        with_info=True,
    )


train_dataset, train_info = load_coco_panoptic('train', shuffle_files=True)
# Evaluation metrics do not depend on order; a fixed order also makes the
# "Check Network Predictions" cell show the same images on every run.
cv_dataset, cv_info = load_coco_panoptic('validation', shuffle_files=False)

# Number of batches per training epoch (passed to LRCallBack and used to slice
# its loss history).
BATCHES = train_dataset.cardinality().numpy()
print(
    BATCHES,
    cv_info.features,
    sep='\n',
)

3697
FeaturesDict({
    'image': Image(shape=(None, None, 3), dtype=tf.uint8),
    'image/filename': Text(shape=(), dtype=tf.string),
    'image/id': tf.int64,
    'panoptic_image': Image(shape=(None, None, 3), dtype=tf.uint8),
    'panoptic_image/filename': Text(shape=(), dtype=tf.string),
    'panoptic_objects': Sequence({
        'area': tf.int64,
        'bbox': BBoxFeature(shape=(4,), dtype=tf.float32),
        'id': tf.int64,
        'is_crowd': tf.bool,
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=133),
    }),
})


# Preprocess Data

In [4]:
# Resizing layers keyed by (height, width). One layer is reused for every batch
# instead of constructing a new tf.keras.Sequential (with freshly numbered layer
# names) on each call.
_RESIZING_LAYERS = {}


def preprocess_img(img):
    """Resize images to IMG_SHAPE[0] x IMG_SHAPE[1], cropping to the target aspect ratio.

    Called once per batch inside ``tf.py_function``.

    Args:
        img: image tensor (a batch of images in this notebook).

    Returns:
        The resized tensor produced by ``tfl.Resizing(..., crop_to_aspect_ratio=True)``.
    """
    size = (IMG_SHAPE[0], IMG_SHAPE[1])
    resizer = _RESIZING_LAYERS.get(size)
    if resizer is None:
        # setdefault keeps a single cached layer even if two map() workers race here.
        resizer = _RESIZING_LAYERS.setdefault(
            size, tfl.Resizing(size[0], size[1], crop_to_aspect_ratio=True))
    return resizer(img)

def preprocess_objects(imgs, objects):
    """Turn a batch of panoptic objects into multi-hot class targets.

    Args:
        imgs: batch of images (eager tensor). Only its shape is read, to derive
            the aspect-ratio scale factors ``sw``/``sh``.
        objects: dict of batched per-object tensors from the tfds
            ``panoptic_objects`` feature. Reads 'bbox' (tfds BBoxFeature:
            normalized ymin, xmin, ymax, xmax), 'label' and, when present, 'area'.

    Returns:
        float32 tensor of shape (batch_size, NUM_CLASSES); entry [i, c] is 1 when
        image i has at least one kept object of class c, else 0.
    """
    batch_size = imgs.shape[0]
    y = np.zeros([batch_size, NUM_CLASSES])
    if batch_size == 0:
        return tf.constant(y, dtype=tf.float32)

    # All images of a dense batch tensor share one shape, so read it once instead
    # of copying the whole batch to NumPy just to inspect shapes.
    size = imgs.shape[1:]
    # NOTE(review): preprocess() passes the already-resized, square batch, so
    # sw == sh == 1 there and the check below reduces to ymin < 1 and xmin < 1.
    # Confirm whether the original image size was meant to be used instead.
    sw, sh = (1, size[0] / size[1]) if (size[0] > size[1]) else (size[1] / size[0], 1)

    batched_bboxs = objects['bbox'].numpy()
    batched_labels = objects['label'].numpy()
    # tfds.load(batch_size=...) pads the variable-length object sequences with
    # zeros. Padding rows have label 0 and an all-zero bbox, which pass the check
    # below and would mark class 0 as present in every image that has fewer
    # objects than the largest image in its batch, so they are skipped.
    if 'area' in objects:
        batched_valid = objects['area'].numpy() > 0
    else:
        batched_valid = np.any(batched_bboxs != 0, axis=-1)

    for i in range(batch_size):
        for bbox, label, valid in zip(batched_bboxs[i], batched_labels[i], batched_valid[i]):
            if not valid:
                continue
            ch, cw, _, _ = bbox  # ymin, xmin of the box
            if ch * sh < 1 and cw * sw < 1:
                y[i, label] = 1
    return tf.constant(y, dtype=tf.float32)
    
def wrapper(func, inp, Tout, name=None):
    """Call ``func`` eagerly through ``tf.py_function`` on a nested input.

    ``tf.py_function`` only accepts a flat list of tensors. ``inp`` is therefore
    flattened for the call (here it is a list holding a tensor and a dict of
    tensors). It is rebuilt with its original structure before being unpacked
    into ``func``.

    Args:
        func: Python callable taking the elements of ``inp`` positionally.
        inp: nested structure of tensors; composite tensors are expanded.
        Tout: output dtype(s), forwarded to ``tf.py_function``.
        name: optional op name.

    Returns:
        Whatever ``tf.py_function`` returns for ``Tout``.
    """
    flat_inputs = tf.nest.flatten(inp, expand_composites=True)

    def _call_with_structure(*flat_args):
        structured_args = tf.nest.pack_sequence_as(
            inp, flat_args, expand_composites=True)
        return func(*structured_args)

    return tf.py_function(
        func=_call_with_structure,
        inp=flat_inputs,
        Tout=Tout,
        name=name,
    )

def preprocess(imgs, objects):
    """Eager preprocessing of one tfds batch into (resized images, multi-hot labels).

    NOTE(review): the label builder receives the *resized* batch, so its
    aspect-ratio factors are derived from IMG_SHAPE rather than from the
    original image size. Confirm this ordering is intended.
    """
    resized_imgs = preprocess_img(imgs)
    targets = preprocess_objects(resized_imgs, objects)
    return resized_imgs, targets

# Every element of the tfds datasets is already a batch (batch_size was passed to
# tfds.load), so shuffle() operates on whole preprocessed batches. One batch is
# BATCH_SIZE x IMG_SHAPE float32 values (about 19 MB at 32 x 224 x 224 x 3). A
# buffer of 1024 batches therefore had to preprocess and hold roughly 20 GB
# before the first training step. A few dozen batches keeps memory bounded;
# file-level shuffling still comes from shuffle_files=True in tfds.load.
SHUFFLE_BUFFER_BATCHES = 64


def _to_model_inputs(example):
    """Map one batched tfds example dict to (images, multi-hot labels)."""
    images, labels = wrapper(
        preprocess,
        [example['image'], example['panoptic_objects']],
        (tf.float32, tf.float32))
    # tf.py_function drops static shape information; restore (and check) it so
    # downstream Keras code sees the shapes the model declares.
    images = tf.ensure_shape(images, (None, *IMG_SHAPE))
    labels = tf.ensure_shape(labels, (None, NUM_CLASSES))
    return images, labels


train_dataset = (
    train_dataset
    .map(_to_model_inputs, num_parallel_calls=PREFETCH)
    .shuffle(SHUFFLE_BUFFER_BATCHES, reshuffle_each_iteration=True)
    .prefetch(buffer_size=PREFETCH)
)
# Validation is not shuffled; prefetching overlaps its preprocessing with evaluation.
cv_dataset = (
    cv_dataset
    .map(_to_model_inputs, num_parallel_calls=PREFETCH)
    .prefetch(buffer_size=PREFETCH)
)

# Hyper-Parameters

In [5]:
## A value of 0 switches the corresponding feature off (see the per-line notes).
## Bracketed ranges are the search ranges considered for tuning.

# Adam clipvalue.
# NOTE(review): confirm that 0 really disables clipping in the installed Keras
# version.
GRADIENT_CLIP = 10      # [5, 100]

# Number of leading MobileNetV3 layers frozen; 0 fine-tunes the whole backbone.
FIXED_LAYERS = 0        # [0, 276]

# Passed to LRCallBack. These are presumably log10 learning rates (the callback
# printed 0.000316 = 10**-3.5).
LR = (-3.5, -4.5)       # [-5.5, -2.5]

EPOCHS = 8              # [0, 30]

# Decision threshold on the sigmoid outputs for the F0.1 / precision / recall
# metrics. It is not the F-beta "beta", which is fixed at 0.1.
BETA = 0.5              # (0, 1)

# Focusing parameter of the binary focal cross-entropy loss.
GAMMA = 4.0             # [2, 6]

# Passed to MobileNetV3Large(dropout_rate=...). With include_top=False Keras does
# not build its own classifier head, so the value is not used there.
DROPOUT = 0.3           # [0.2, 1)

# L1 / L2 kernel regularization of the final Dense layer.
L1 = 0                  # [0, 1e-4]
L2 = 1e-4               # [0, 1e-4]

# Define Model

In [6]:
class PercentageModel(tf.keras.Model):
    """Keras model that reports its metrics as percentages.

    Only metric values are multiplied by 100; losses keep their natural scale.
    """

    def compute_metrics(self, x, y, y_pred, sample_weight):
        """Compute the compiled metrics and scale every non-loss result by 100.

        Keras includes the loss tracker in the returned dict ('loss', and
        '<output>_loss' for multi-output models). Without the key check, the
        logged 'loss' / 'val_loss' would also be multiplied by 100.
        """
        metric_results = super().compute_metrics(x, y, y_pred, sample_weight)
        return {
            key: value if (key == 'loss' or key.endswith('_loss')) else value * 100
            for key, value in metric_results.items()
        }

In [7]:
def mobilenet_model():
    """Build the MobileNetV3-Large backbone without its classifier head.

    Every layer is trainable except the first FIXED_LAYERS, which are frozen
    (FIXED_LAYERS == 0 freezes nothing).

    NOTE: dropout_rate is forwarded, but Keras only uses it inside the
    classifier head, which include_top=False omits.
    """
    backbone = tf.keras.applications.MobileNetV3Large(
        dropout_rate=DROPOUT,
        include_top=False,
        input_shape=IMG_SHAPE,
    )
    backbone.trainable = True
    frozen_layers = backbone.layers[:FIXED_LAYERS]
    for frozen in frozen_layers:
        frozen.trainable = False
    return backbone


def top_model(input_shape, dropout_rate=None):
    """Classification head: global average pooling -> dropout -> Dense -> sigmoid.

    Args:
        input_shape: shape of the backbone feature map, without the batch axis.
        dropout_rate: dropout applied to the pooled features. Defaults to the
            DROPOUT hyper-parameter; 0 leaves the Dropout layer out.

    Returns:
        tf.keras.Model named 'Top' mapping features to NUM_CLASSES independent
        sigmoid probabilities.
    """
    if dropout_rate is None:
        # Resolved at call time so re-running the hyper-parameter cell takes effect.
        dropout_rate = DROPOUT

    inputs = tf.keras.Input(name='Input', shape=input_shape)
    x = tfl.GlobalAveragePooling2D(name="GlobalAveragePool2D")(inputs)
    # MobileNetV3Large(include_top=False) ignores its dropout_rate argument, so
    # the DROPOUT hyper-parameter must be applied here to have any effect.
    if dropout_rate:
        x = tfl.Dropout(dropout_rate, name='Dropout_0')(x)
    x = tfl.Dense(
        units=NUM_CLASSES,
        name='Dense_0',
        kernel_regularizer=tf.keras.regularizers.L1L2(L1, L2))(x)
    # Sigmoid, not softmax: each class is an independent (multi-label) output,
    # matching BinaryFocalCrossentropy with its default from_logits=False.
    outputs = tfl.Activation('sigmoid', name='Sigmoid')(x)
    return tf.keras.Model(inputs, outputs, name='Top')


def train_models():
    """Assemble backbone and head into the end-to-end PercentageModel.

    Returns:
        dict with the backbone ('mobilenet'), the head ('top') and the full model
        ('train'). The full model calls the two sub-models, so they share layers.
    """
    backbone = mobilenet_model()
    image_input = tf.keras.Input(name='Input', shape=IMG_SHAPE)
    head = top_model(input_shape=backbone.layers[-1].output_shape[1:])

    probabilities = head(backbone(image_input))
    full_model = PercentageModel(image_input, probabilities, name='full')

    return {'mobilenet': backbone, 'top': head, 'train': full_model}

## Estimate optimal learning rate

# Define Metrics

In [8]:
optimizer = tf.keras.optimizers.Adam(clipvalue=GRADIENT_CLIP)

# All metrics use BETA as the decision threshold on the sigmoid outputs.
metrics = [
    tfa.metrics.FBetaScore(
        num_classes=NUM_CLASSES,
        average='weighted',
        beta=0.1,
        threshold=BETA,
        name='F0.1',),
    tf.keras.metrics.Precision(
        thresholds=BETA,
        name='pre',),
    tf.keras.metrics.Recall(
        thresholds=BETA,
        name='rec',),]

models = train_models()
model = models['train']
model.compile(
    # Pass gamma by keyword. In newer TF releases the first positional parameter
    # of BinaryFocalCrossentropy is apply_class_balancing, so a positional GAMMA
    # would leave gamma at its default.
    loss=tf.keras.losses.BinaryFocalCrossentropy(gamma=GAMMA),
    optimizer=optimizer,
    metrics=metrics,
    # NOTE(review): eager execution is much slower; confirm it is still required.
    run_eagerly=True,
)
# summary() prints the table itself and returns None; wrapping it in print()
# only added a stray "None" line.
model.summary()

# print('Total Flops:', get_flops(model, batch_size=BATCH_SIZE) // 2)

Model: "full"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 Input (InputLayer)          [(None, 224, 224, 3)]     0         
                                                                 
 MobilenetV3large (Functiona  (None, 7, 7, 960)        2996352   
 l)                                                              
                                                                 
 Top (Functional)            (None, 133)               127813    
                                                                 
Total params: 3,124,165
Trainable params: 3,099,765
Non-trainable params: 24,400
_________________________________________________________________
None


# Train

In [None]:
# Progress bar; every logged value is rendered with a trailing '%'
# (PercentageModel scales the metrics by 100).
tqdm_callback = tfa.callbacks.TQDMProgressBar(
    update_per_second=1,
    metrics_separator=' | ',
    metrics_format='{name}: {value:0.2f}%',
    epoch_bar_format='{n_fmt}/{total_fmt} | ETA: {remaining} | Elapsed: {elapsed} {bar} {desc}',
)

# Learning-rate callback. Later cells also read its per-batch losses
# (lr_callback.batch_losses), and fit() takes its epoch count from it.
lr_callback = LRCallBack(l_r=LR, epochs=EPOCHS, batches=BATCHES)

history = model.fit(
    train_dataset,
    epochs=lr_callback.epochs,
    initial_epoch=0,
    validation_data=cv_dataset,
    callbacks=[tqdm_callback, lr_callback],
    verbose=0,  # the TQDM callback replaces Keras' own progress output
)

Training:   0%|                                              0/8 ETA: ?s,  ?epochs/s

Epoch 1/8


0/3697 | ETA: ? | Elapsed: 00:00                                                    

LR set to: 0.000316


# Plot the training loss and the training/validation metrics

In [None]:
# Per-batch training loss from LRCallBack.batch_losses, skipping the first
# int(BATCHES * 0.2) batches of the run.
first_batch = int(BATCHES * 0.2)
y = lr_callback.batch_losses[first_batch:]
# Use the actual batch indices. np.linspace(0, len(y), len(y)) produced
# fractional x positions that were not offset by the skipped batches.
x = np.arange(first_batch, first_batch + len(y))
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=x,
        y=y,
        mode='lines',
        name='Loss',
    )
)
fig.update_layout(
    template='plotly_dark',
    height=400, width=1500,
    title_text='Training Loss',
    xaxis_title='Batch',
    yaxis_title='Loss',
)
fig.show()

In [None]:
# The tracked metric is the F0.1 score (reported in % by PercentageModel), not
# accuracy, and the loss is binary focal cross-entropy, so label them as such.
f_score = history.history['F0.1']
val_f_score = history.history['val_F0.1']

loss = history.history['loss']
val_loss = history.history['val_loss']

_, (ax_score, ax_loss) = plt.subplots(2, 1, figsize=(12, 12))

ax_score.plot(f_score, label='Training F0.1')
ax_score.plot(val_f_score, label='Validation F0.1')
ax_score.legend(loc='lower right')
ax_score.set_ylabel('F0.1 score (%)')
ax_score.set_title('Training and Validation F0.1 Score')

ax_loss.plot(loss, label='Training Loss')
ax_loss.plot(val_loss, label='Validation Loss')
ax_loss.legend(loc='upper right')
ax_loss.set_ylabel('Binary focal cross-entropy')
ax_loss.set_title('Training and Validation Loss')
ax_loss.set_xlabel('epoch')
plt.show()

# Save the model

In [None]:
os.makedirs('models', exist_ok=True)
model.save(os.path.join('models', 'main.h5'))
model.save_weights(os.path.join('models', 'main_weights.h5'))

# Class names in label-index order, taken from the dataset metadata so the file
# lines up with the model's NUM_CLASSES outputs. all_class_names was not defined
# by any earlier cell shown here, so this cell failed on a fresh kernel.
all_class_names = train_info.features['panoptic_objects'].feature['label'].names
with open(os.path.join('models', 'main_classes.txt'), 'w', encoding='utf-8') as file:
    for class_name in all_class_names:
        file.write(class_name + '\n')

# Convert to TFLite (post-training float16 quantization)

In [None]:
# Post-training float16 quantization of the trained Keras model.
# NOTE(review): this cell previously converted `model2`, loaded
# `quantized_tflite_model` into the interpreter and printed
# `q_aware_model_accuracy`. None of these are defined in this notebook. It now
# converts the trained `model` and evaluates the model it just converted.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]

tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Quant TFLite test_accuracy:', test_accuracy)

In [None]:
os.makedirs('models', exist_ok=True)
# Close the file explicitly with a context manager rather than relying on the
# interpreter to collect the temporary file object.
# NOTE(review): generate_metadata() presumably reads this file; keep the path in sync.
with open("models\\mobilenet_v3_large_food_classifier_op.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)
generate_metadata()

# Check Network Predictions

In [None]:
# Class names in label-index order from the dataset metadata.
class_names = cv_info.features['panoptic_objects'].feature['label'].names

image_batch, label_batch = next(iter(cv_dataset))
# One forward pass for the whole batch; previously the model ran once per subplot.
# `model` is the trained network (the previous `model2` is not defined here).
prob_batch = model(image_batch, training=False).numpy()
label_batch = label_batch.numpy()

plt.figure(figsize=(25, 25))
for item in range(min(9, len(prob_batch))):
    ax = plt.subplot(3, 3, item + 1)
    plt.imshow(image_batch[item].numpy().astype("uint8"))
    # Targets are multi-hot. Show every present class and every class whose
    # probability reaches the BETA threshold, not a single argmax.
    true_names = [class_names[c] for c in np.flatnonzero(label_batch[item])]
    pred_names = [class_names[c] for c in np.flatnonzero(prob_batch[item] >= BETA)]
    plt.title('true: ' + ', '.join(true_names) + '\npred: ' + ', '.join(pred_names))
    plt.axis("off")

In [None]:
# tf.keras.utils is the canonical home of image_dataset_from_directory;
# tensorflow.keras.preprocessing only re-exports it as a legacy alias. This
# also avoids an import in the middle of the notebook.
test_dataset = tf.keras.utils.image_dataset_from_directory(
    os.path.join('datasets', 'test'),
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
    label_mode='categorical',
)
# Sub-directory names, in the index order used by the one-hot labels.
test_labels = test_dataset.class_names


In [None]:
# Class names of the model outputs, in label-index order, from the dataset metadata.
class_names = cv_info.features['panoptic_objects'].feature['label'].names

image_batch, label_batch = next(iter(test_dataset))
# One forward pass for the batch. `model` is the trained network (the previous
# `model2` is not defined in this notebook).
prob_batch = model(image_batch, training=False).numpy()

plt.figure(figsize=(25, 25))
for item in range(1):
    ax = plt.subplot(3, 3, item + 1)
    plt.imshow(image_batch[item].numpy().astype("uint8"))
    # Ground truth comes from the test directory names, the prediction from the
    # model's own class list; the two label spaces are not necessarily the same.
    label = test_labels[int(np.argmax(label_batch[item]))]
    pred = class_names[int(np.argmax(prob_batch[item]))]
    plt.title(label + '-' + pred)
    plt.axis("off")