In [None]:
# ML Compute device selection only exists in Apple's tensorflow_macos build.
# On stock TensorFlow this import fails, so fall back to TensorFlow's default
# device placement instead of aborting the notebook in its first cell.
try:
    from tensorflow.python.compiler.mlcompute import mlcompute
except ImportError:
    mlcompute = None
    print("ML Compute not available; using TensorFlow's default device placement.")
else:
    mlcompute.set_mlc_device(device_name="any")

In [None]:
import tensorflow as tf
import numpy as np
import cv2
import os
import matplotlib.pylab as plt

In [None]:
from tensorflow import keras

## 参考 (References)
- [TensorFlow guide: Transfer learning & fine-tuning](https://www.tensorflow.org/guide/keras/transfer_learning)

In [None]:
import pathlib
import random

In [None]:
# Sanity check: show what sits directly under ./images (the split folders
# used below, e.g. aug_train / aug_valid).
images_dir = pathlib.Path('./images')
for entry in images_dir.iterdir():
    print(entry)

In [None]:
# Dataset layout: images/<split>/<class name>/<image file>.
data_root = pathlib.Path.cwd() / 'images'

train_path = data_root / 'aug_train'
valid_path = data_root / 'aug_valid'

# Fixed seed so the shuffled file order is reproducible across runs.
SEED = 42


def list_image_paths(split_dir, seed=SEED):
    """Return the image files under ``split_dir/<class>/`` as shuffled strings.

    ``pathlib.Path.glob`` (unlike ``glob.glob``) also matches hidden files such
    as macOS ``.DS_Store``, and ``'*/*'`` can match nested directories; neither
    can be decoded by ``tf.image.decode_jpeg``, so only regular, non-hidden
    files are kept. Paths are sorted before shuffling so the order depends only
    on ``seed``, not on the filesystem's listing order.
    """
    paths = [str(p) for p in sorted(split_dir.glob('*/*'))
             if p.is_file() and not p.name.startswith('.')]
    random.Random(seed).shuffle(paths)
    return paths


train_image_paths = list_image_paths(train_path)
print(len(train_image_paths))
assert train_image_paths, f"no training images found under {train_path}"

valid_image_paths = list_image_paths(valid_path)
print(len(valid_image_paths))
assert valid_image_paths, f"no validation images found under {valid_path}"

AUTOTUNE = tf.data.AUTOTUNE



In [None]:
# Class names are the sub-directories of the training split, sorted so the
# name -> index mapping is deterministic.
label_names = sorted(item.name for item in train_path.glob('*/') if item.is_dir())
label_to_index = {name: index for index, name in enumerate(label_names)}
# The model ends in a single logit trained with binary cross-entropy, so there
# must be exactly two classes (labels 0 and 1).
assert len(label_names) == 2, f"expected 2 class folders in {train_path}, found {label_names}"


def labels_for(paths):
    """Map image paths to class indices via each file's parent directory name.

    Raises KeyError if a path's folder is not one of the training classes.
    """
    return [label_to_index[pathlib.Path(p).parent.name] for p in paths]


train_image_labels = labels_for(train_image_paths)
valid_image_labels = labels_for(valid_image_paths)

print(train_image_labels[:10])
print(valid_image_labels[:10])

In [None]:
def preprocess_image(image):
    """Decode raw JPEG bytes into a 192x192 RGB float tensor scaled to [0, 1].

    The image is resized straight to 192x192, so the aspect ratio is not kept.
    """
    decoded = tf.image.decode_jpeg(image, channels=3)
    resized = tf.image.resize(decoded, [192, 192])
    return resized / 255.0


def load_and_preprocess_image(path):
    """Read the file at ``path`` and return it processed by ``preprocess_image``."""
    return preprocess_image(tf.io.read_file(path))


def load_and_preprocess_from_path_label(path, label):
    """Turn a ``(path, label)`` pair into ``(preprocessed image, label)``."""
    return load_and_preprocess_image(path), label

In [None]:
# train_ds = tf.data.Dataset.from_tensor_slices( (train_image_paths, train_image_labels))
# valid_ds = tf.data.Dataset.from_tensor_slices( (valid_image_paths, valid_image_labels))

In [None]:
# train_image_label_ds = train_ds.map(load_and_preprocess_from_path_label)
# valid_image_label_ds = valid_ds.map(load_and_preprocess_from_path_label)
# train_image_label_ds

In [None]:
# Per-split datasets of file paths, then of preprocessed images (decoded in parallel).
train_path_ds, valid_path_ds = (
    tf.data.Dataset.from_tensor_slices(paths)
    for paths in (train_image_paths, valid_image_paths)
)

train_image_ds, valid_image_ds = (
    path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
    for path_ds in (train_path_ds, valid_path_ds)
)

In [None]:
# Preview the first four preprocessed training images in one 2x2 grid.
# plt.show() is called once after the loop: calling it per image closes the
# figure, which scatters the panels across separate figures.
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for ax, image in zip(axes.flat, train_image_ds.take(4)):
    ax.imshow(image)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    print(image.shape)
plt.show()

In [None]:
# Per-split datasets of int64 class indices, in the same order as the image paths.
train_label_ds, valid_label_ds = [
    tf.data.Dataset.from_tensor_slices(tf.cast(labels, tf.int64))
    for labels in (train_image_labels, valid_image_labels)
]



In [None]:
# Pair each image with its label, so elements become (image, label) tuples.
train_image_label_ds, valid_image_label_ds = (
    tf.data.Dataset.zip(pair)
    for pair in ((train_image_ds, train_label_ds), (valid_image_ds, valid_label_ds))
)

In [None]:
BATCH_SIZE = 32
# Training: reshuffle every epoch with a buffer spanning the whole split.
# The buffer holds decoded images, so its memory grows with the dataset size.
train_ds = (train_image_label_ds
            .shuffle(buffer_size=len(train_image_paths))
            .batch(BATCH_SIZE)
            .prefetch(buffer_size=AUTOTUNE))
# Validation: loss/accuracy are averaged over all samples, so order does not
# matter (up to float rounding). Skip the full-size shuffle buffer, which would
# hold every decoded validation image, and just batch + prefetch.
valid_ds = valid_image_label_ds.batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)


In [None]:
# ImageNet-pretrained MobileNetV2 without its classifier head, frozen so it
# acts as a fixed feature extractor for 192x192 RGB inputs.
mobile_net = tf.keras.applications.MobileNetV2(
    input_shape=(192, 192, 3),
    include_top=False,
    weights='imagenet',
)
mobile_net.trainable = False


In [None]:
def change_range(image, label):
    """Rescale ``image`` from [0, 1] to [-1, 1]; ``label`` passes through unchanged."""
    rescaled = image * 2 - 1
    return rescaled, label

# Rescale pixels from [0, 1] to [-1, 1], the input range MobileNetV2 expects.
# prefetch() goes last so this map also overlaps with training.
# NOTE: this rebinds train_ds/valid_ds; re-running the cell alone would rescale
# twice, so re-run the batching cell first.
train_ds = train_ds.map(change_range, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)
valid_ds = valid_ds.map(change_range, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

In [None]:
# Classification head on the frozen backbone:
# feature maps -> global average pool -> dropout -> one unactivated logit.
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
prediction_layer = tf.keras.layers.Dense(1)

inputs = tf.keras.Input(shape=(192, 192, 3))
# training=False keeps the backbone's BatchNorm layers in inference mode,
# including later when part of the backbone is unfrozen for fine-tuning.
features = mobile_net(inputs, training=False)
pooled = global_average_layer(features)
regularized = tf.keras.layers.Dropout(0.25)(pooled)
outputs = prediction_layer(regularized)

model = tf.keras.Model(inputs, outputs)

In [None]:
# Learning rate for training the new classification head.
base_learning_rate = 0.0001
# Use `learning_rate`, not the deprecated `lr` alias: newer Keras optimizers
# warn and ignore `lr` (falling back to the 1e-3 default) or reject it outright.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [None]:
# One preliminary training epoch. It is not recorded in `history` (created by
# the next cell), so the model trains one epoch more than the curves show.
# steps_per_epoch must be an int: integer ceil-division counts every batch,
# including a final partial one.
steps_per_epoch = (len(train_image_paths) + BATCH_SIZE - 1) // BATCH_SIZE
model.fit(train_ds, epochs=1, steps_per_epoch=steps_per_epoch)


In [None]:
# Number of epochs for training the head with the backbone frozen; later cells
# need this count to place fine-tuning after it.
initial_epochs = 20
history = model.fit(train_ds, epochs=initial_epochs, validation_data=valid_ds)

In [None]:
# Learning curves of the frozen-backbone run.
acc, val_acc = history.history['accuracy'], history.history['val_accuracy']
loss, val_loss = history.history['loss'], history.history['val_loss']

fig, (acc_ax, loss_ax) = plt.subplots(2, 1, figsize=(8, 8))

acc_ax.plot(acc, label='Training Accuracy')
acc_ax.plot(val_acc, label='Validation Accuracy')
acc_ax.legend(loc='lower right')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_ylim([min(acc_ax.get_ylim()), 1])
acc_ax.set_title('Training and Validation Accuracy')

loss_ax.plot(loss, label='Training Loss')
loss_ax.plot(val_loss, label='Validation Loss')
loss_ax.legend(loc='upper right')
loss_ax.set_ylabel('Cross Entropy')
loss_ax.set_ylim([0, 1.0])
loss_ax.set_title('Training and Validation Loss')
loss_ax.set_xlabel('epoch')
plt.show()

In [None]:
# Persist the model trained with the frozen backbone, i.e. before the
# fine-tuning cells below. NOTE(review): the fine-tuned model is not saved
# anywhere in this notebook, and a later cell reloads this file into `model`.
model.save('transfer.h5')

In [None]:
# Unfreeze the backbone, then re-freeze its first 100 layers so only the upper
# part of MobileNetV2 is updated during fine-tuning.
mobile_net.trainable = True
fine_tune_at = 100
for layer in mobile_net.layers[:fine_tune_at]:
    layer.trainable = False
    

In [None]:
# Recompile so the new trainable flags take effect. Fine-tuning uses a 10x
# smaller learning rate. Pass `learning_rate`: the deprecated `lr` alias is
# ignored or rejected by newer Keras optimizers.
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate / 10),
              metrics=['accuracy'])

In [None]:
# Show the architecture and trainable vs. non-trainable parameter counts after
# partially unfreezing the backbone.
model.summary()

In [None]:
fine_tune_epochs = 10
# Continue the epoch numbering right after the last epoch recorded in `history`.
# Keras trains epochs [initial_epoch, epochs), so this runs exactly
# `fine_tune_epochs` epochs, whatever length the initial run had.
first_fine_tune_epoch = history.epoch[-1] + 1
total_epochs = first_fine_tune_epoch + fine_tune_epochs

history_fine = model.fit(train_ds,
                         epochs=total_epochs,
                         initial_epoch=first_fine_tune_epoch,
                         validation_data=valid_ds)

In [None]:
# Join the initial and fine-tuning curves into new lists. `+` (rather than
# `+=` on the lists taken from history.history) leaves `history` untouched and
# keeps this cell safe to re-run.
acc = history.history['accuracy'] + history_fine.history['accuracy']
val_acc = history.history['val_accuracy'] + history_fine.history['val_accuracy']

loss = history.history['loss'] + history_fine.history['loss']
val_loss = history.history['val_loss'] + history_fine.history['val_loss']
# Length of the frozen-backbone run, i.e. where fine-tuning starts in the plots.
initial_epochs = len(history.epoch)

In [None]:
# Learning curves across both runs. The marker sits on the last epoch of the
# frozen-backbone run. It is derived from `history` rather than a hard-coded
# epoch count, so it stays correct if the initial run length changes.
fine_tune_start = len(history.epoch) - 1

fig, (acc_ax, loss_ax) = plt.subplots(2, 1, figsize=(8, 8))

acc_ax.plot(acc, label='Training Accuracy')
acc_ax.plot(val_acc, label='Validation Accuracy')
acc_ax.set_ylim([0.8, 1])  # accuracies below 0.8 are clipped from view
acc_ax.plot([fine_tune_start, fine_tune_start],
            acc_ax.get_ylim(), label='Start Fine Tuning')
acc_ax.legend(loc='lower right')
acc_ax.set_title('Training and Validation Accuracy')

loss_ax.plot(loss, label='Training Loss')
loss_ax.plot(val_loss, label='Validation Loss')
loss_ax.set_ylim([0, 1.0])
loss_ax.plot([fine_tune_start, fine_tune_start],
             loss_ax.get_ylim(), label='Start Fine Tuning')
loss_ax.legend(loc='upper right')
loss_ax.set_title('Training and Validation Loss')
loss_ax.set_xlabel('epoch')
plt.show()

In [None]:
# Saved model to evaluate. Alternative exported model:
#   "~/ML/tfjs/dalek_or_cyberman/modeltf/tuned_model.h5"
# NOTE: "transfer.h5" is written before fine-tuning, so loading it replaces the
# in-memory fine-tuned `model` with the frozen-backbone version.
model_path = "transfer.h5"

# Keras/h5py do not expand "~", so expand it here for home-relative paths.
model = keras.models.load_model(os.path.expanduser(model_path))

In [None]:
# Image to classify. Other candidates that were tried:
#   '~/Downloads/dalek.jpg', '~/Downloads/images-3.jpeg',
#   './images/valid/dalek/d_000.jpg'
impath = './images/train/cyberman/cyberman_000.jpg'

size = (192, 192)  # must match the model's 192x192 input

# cv2.imread does not expand '~' and signals a missing/unreadable file by
# returning None instead of raising.
im = cv2.imread(os.path.expanduser(impath))
if im is None:
    raise FileNotFoundError(f"could not read image {impath!r}")

# OpenCV decodes to BGR, but training images came from tf.image.decode_jpeg
# (RGB); convert so the channel order matches what the model learned.
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
# Plain stretch resize, matching tf.image.resize in preprocess_image (no padding).
im = cv2.resize(im, size)
plt.imshow(im)

# Same scaling as training: [0, 255] -> [0, 1] -> [-1, 1] (see change_range).
batch = np.expand_dims(im, axis=0).astype(np.float32) / 255.0
batch = batch * 2 - 1
logits = model.predict(batch)

# The model outputs a logit; its sigmoid is the probability of label index 1,
# i.e. the second name in the sorted label_names.
predictions = tf.nn.sigmoid(logits)
print(predictions.numpy()[0][0])

In [None]:
def padding(im):
    """Zero-pad ``im`` on its shorter side so that height == width.

    The original content stays centred. When the size difference is odd, the
    extra row/column goes to the bottom/right so the result is exactly square.
    Works for 2-D (grayscale) and H x W x C arrays and preserves the dtype.
    A square input is returned unchanged (same object).
    """
    h, w = im.shape[:2]
    if h == w:
        return im
    diff = abs(h - w)
    before, after = diff // 2, diff - diff // 2
    # Tall image: pad columns (left/right). Wide image: pad rows (top/bottom).
    spatial = [(0, 0), (before, after)] if h > w else [(before, after), (0, 0)]
    pad_width = spatial + [(0, 0)] * (im.ndim - 2)  # never pad the channel axis
    return np.pad(im, pad_width, mode='constant', constant_values=0)


def resize_with_pad(im, size):
    """Zero-pad ``im`` to a square (see ``padding``) and resize it to ``size``.

    ``size`` is passed to ``cv2.resize`` and is therefore ``(width, height)``.
    """
    return cv2.resize(padding(im), size)