### Transfer learning

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Download (or read from the local Keras cache) the CIFAR-10 train/test arrays.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Boolean masks, one entry per sample, that are True only for labels 0, 1 and 2.
mask_train = y_train.ravel() < 3
mask_test = y_test.ravel() < 3

# Keep the selected samples: images are cast to float32, labels become 1-D.
x_train, y_train = x_train[mask_train].astype("float32"), y_train[mask_train].ravel()
x_test, y_test = x_test[mask_test].astype("float32"), y_test[mask_test].ravel()

print(x_train.shape, y_train.shape)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 0us/step
(15000, 32, 32, 3) (15000,)


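Preprocessing: `Resizing` upsamples the 32×32 CIFAR images to `IMG_SIZE`×`IMG_SIZE`, the input size used for MobileNetV2 here, and `preprocess_input` rescales the pixel values into the format MobileNetV2 expects.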

In [3]:
IMG_SIZE = 96  # 128 also works, but 96 is faster

# Input pipeline that becomes part of the model: upsample every image to
# IMG_SIZE x IMG_SIZE, then map pixel values to [-1, 1] (assumes inputs are in
# the usual [0, 255] range).
# Rescaling(1/127.5, offset=-1) applies the same x / 127.5 - 1 arithmetic as
# mobilenet_v2.preprocess_input, but it is a built-in layer that serializes by
# config. A Lambda wrapper would instead be saved as Python bytecode, which
# Keras refuses to reload in safe mode and which is not portable.
preprocess = tf.keras.Sequential([
    tf.keras.layers.Resizing(IMG_SIZE, IMG_SIZE),
    tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1.0),
])

Pretrained base + freeze

In [4]:
# Input shape of the backbone: square RGB images of side IMG_SIZE.
backbone_input_shape = (IMG_SIZE, IMG_SIZE, 3)

# MobileNetV2 with ImageNet weights and without its original classification
# top, so it outputs feature maps rather than ImageNet class scores.
base = tf.keras.applications.MobileNetV2(
    weights="imagenet",
    include_top=False,
    input_shape=backbone_input_shape,
)

# Freeze the backbone weights so that training only updates layers added on top.
base.trainable = False

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_96_no_top.h5
[1m9406464/9406464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


New classification head.

In [5]:
# Seed Python, NumPy and TensorFlow before any new layer is created, so that
# the head's initial weights, the dropout masks and the batch shuffling repeat
# across "Restart & Run All". GPU kernels may still add small nondeterminism.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Full classifier: raw 32x32 RGB input -> preprocessing -> frozen backbone
# features -> global average pooling -> dropout -> 3-way softmax output.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32, 32, 3)),
    preprocess,
    base,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(3, activation="softmax")
])

model.summary()




Compile + fit

In [6]:
# Adam with a 1e-3 learning rate; sparse categorical cross-entropy takes the
# integer class labels directly (no one-hot encoding needed).
adam = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer=adam, loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# Training settings gathered in one place; validation_split=0.1 holds out 10%
# of the training arrays for validation.
fit_options = dict(epochs=5, batch_size=64, validation_split=0.1, verbose=1)
history = model.fit(x_train, y_train, **fit_options)


Epoch 1/5
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m66s[0m 258ms/step - accuracy: 0.9083 - loss: 0.2417 - val_accuracy: 0.9573 - val_loss: 0.1238
Epoch 2/5
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m51s[0m 242ms/step - accuracy: 0.9555 - loss: 0.1264 - val_accuracy: 0.9640 - val_loss: 0.1092
Epoch 3/5
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 256ms/step - accuracy: 0.9607 - loss: 0.1075 - val_accuracy: 0.9627 - val_loss: 0.1018
Epoch 4/5
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m52s[0m 245ms/step - accuracy: 0.9613 - loss: 0.1025 - val_accuracy: 0.9653 - val_loss: 0.1007
Epoch 5/5
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 273ms/step - accuracy: 0.9644 - loss: 0.0934 - val_accuracy: 0.9580 - val_loss: 0.1231


Evaluate

In [7]:
# Score the trained model on the test split. evaluate() returns the loss
# followed by the metrics passed to compile(), here just accuracy.
test_scores = model.evaluate(x_test, y_test, verbose=0)
loss, acc = test_scores
print("Test accuracy (transfer learning, frozen base):", acc)

Test accuracy (transfer learning, frozen base): 0.9589999914169312
