
# 簡介 #

本專案使用分布式運算（TPU）訓練深度學習模型，將花卉圖片分類為 104 個類別。

In [1]:
import seaborn as sns

import matplotlib.pyplot as plt
from matplotlib import cm
import math, re, os
import pandas as pd
import numpy as np
import random
import plotly.express as px

import tensorflow as tf
# Report the TensorFlow version this notebook is running on.
print(f"Tensorflow version {tf.__version__}")

# 定義分布運算策略 
TPU 有 8 個核心，分布式策略會把它們視為 8 個平行的運算單元（replica），效果類似同時使用 8 張 GPU。

In [2]:
# Look for a TPU; a ValueError from the resolver is treated as "no TPU attached".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # No TPU: the default strategy (single replica on CPU/GPU).
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("運算單元數: ", strategy.num_replicas_in_sync)

# 參數設定

In [3]:
# Input resolution (height, width); also selects the TFRecord folders read from GCS.
IMAGE_SIZE = [512, 512]
BATCH_SIZE = 16  # batch size per replica; the global batch is this times the replica count
EPOCHS = 20
AUTO = tf.data.experimental.AUTOTUNE  # let tf.data pick parallelism / prefetch depth

# Model name(s); more than one entry enables the ensemble section further down.
MODEL_NAME = ['EfficientNetB7']
# The 104 flower class names; list position == integer class id used as the label.
# 'cyclamen ' and 'hippeastrum ' carry a trailing space; kept exactly as given.
CLASSES = [
    # 0-9
    'pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'wild geranium',
    'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle',
    # 10-19
    'snapdragon', "colt's foot", 'king protea', 'spear thistle', 'yellow iris',
    'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower', 'giant white arum lily',
    # 20-29
    'fire lily', 'pincushion flower', 'fritillary', 'red ginger', 'grape hyacinth',
    'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke', 'sweet william',
    # 30-39
    'carnation', 'garden phlox', 'love in the mist', 'cosmos', 'alpine sea holly',
    'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose',
    # 40-49
    'barberton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue',
    'wallflower', 'marigold', 'buttercup', 'daisy', 'common dandelion',
    # 50-59
    'petunia', 'wild pansy', 'primula', 'sunflower', 'lilac hibiscus',
    'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia', 'pink-yellow dahlia',
    # 60-69
    'cautleya spicata', 'japanese anemone', 'black-eyed susan', 'silverbush', 'californian poppy',
    'osteospermum', 'spring crocus', 'iris', 'windflower', 'tree poppy',
    # 70-79
    'gazania', 'azalea', 'water lily', 'rose', 'thorn apple',
    'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium',
    # 80-89
    'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose',
    'tree mallow', 'magnolia', 'cyclamen ', 'watercress', 'canna lily',
    # 90-99
    'hippeastrum ', 'bee balm', 'pink quill', 'foxglove', 'bougainvillea',
    'camellia', 'mallow', 'mexican petunia', 'bromelia', 'blanket flower',
    # 100-103
    'trumpet creeper', 'blackberry lily', 'common tulip', 'wild rose',
]

# 從GCS載入數據

使用 TPU 時，資料必須存放在 Google Cloud Storage（GCS）bucket 中，TPU 才能直接且低延遲地讀取資料。

In [4]:
from kaggle_datasets import KaggleDatasets

# Competition TFRecords on GCS, in the sub-folder that matches IMAGE_SIZE.
gcs_ds_path = KaggleDatasets().get_gcs_path('tpu-getting-started')
gcs_path = gcs_ds_path + f'/tfrecords-jpeg-{IMAGE_SIZE[0]}x{IMAGE_SIZE[1]}'

# One shard list per split: train / val / test (unlabeled).
training_filenames, validation_filenames, test_filenames = [
    tf.io.gfile.glob(gcs_path + pattern)
    for pattern in ('/train/*.tfrec', '/val/*.tfrec', '/test/*.tfrec')
]

## 匯入額外資料集

In [5]:
gcs_ds_path_EXT = KaggleDatasets().get_gcs_path('tf-flower-photo-tfrec')

# The external TFRecords are split by resolution; pick the folder matching IMAGE_SIZE.
gcs_path_SELECT_EXT = {
    192: '/tfrecords-jpeg-192x192',
    224: '/tfrecords-jpeg-224x224',
    331: '/tfrecords-jpeg-331x331',
    512: '/tfrecords-jpeg-512x512',
}
gcs_path_EXT = gcs_path_SELECT_EXT[IMAGE_SIZE[0]]

# One shard list per external source: <bucket><source><resolution folder>/*.tfrec
imageNet_files, inatureList_files, openImage_files, oxford_files, tensorflow_files = [
    tf.io.gfile.glob(gcs_ds_path_EXT + source + gcs_path_EXT + '/*.tfrec')
    for source in ('/imagenet', '/inaturalist', '/openimage', '/oxford_102', '/tf_flowers')
]

additional_training_filenames = [
    *imageNet_files, *inatureList_files, *openImage_files, *oxford_files, *tensorflow_files,
]

# NOTE(review): if any external source overlaps the competition's val split, validation
# scores will be optimistic — confirm there is no overlap.
training_filenames = [*training_filenames, *additional_training_filenames]

# 數據讀取function


In [6]:
def decode_image(image_data):
    """Decode a JPEG byte string into a float32 tensor of shape (IMAGE_SIZE[0], IMAGE_SIZE[1], 3) in [0, 1]."""
    pixels = tf.cast(tf.image.decode_jpeg(image_data, channels=3), tf.float32)
    pixels = pixels / 255.0  # scale 0..255 -> 0..1
    return tf.reshape(pixels, [IMAGE_SIZE[0], IMAGE_SIZE[1], 3])

def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into (decoded image, int32 class id)."""
    features = tf.io.parse_single_example(
        example,
        {
            "image": tf.io.FixedLenFeature([], tf.string),
            "class": tf.io.FixedLenFeature([], tf.int64),
        },
    )
    return decode_image(features['image']), tf.cast(features['class'], tf.int32)

def read_unlabeled_tfrecord(example):
    """Parse one serialized test example into (decoded image, id byte string)."""
    features = tf.io.parse_single_example(
        example,
        {
            "image": tf.io.FixedLenFeature([], tf.string),
            "id": tf.io.FixedLenFeature([], tf.string),
        },
    )
    return decode_image(features['image']), features['id']

def load_dataset(filenames, labeled=True, ordered=False):
    """Read TFRecord shards into a dataset of (image, label) or, if labeled=False, (image, id).

    With ordered=False, non-deterministic element order is allowed for read throughput.
    Pass ordered=True when the output must line up with a second pass over the same files.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False  # faster, order not preserved
    parser = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord
    return (
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
        .with_options(options)
        .map(parser, num_parallel_calls=AUTO)
    )

# 數據增強

In [7]:
SEED = 2020  # op-level seed for the tf.image random augmentations below

# Random blockout (random erasing)
def random_blockout(img, sl=0.1, sh=0.2, rl=0.4, prob=0.75):
    """With probability `prob`, zero out one randomly sized and placed rectangle of `img`.

    Side lengths of the rectangle are drawn from
    [round(sqrt(area*sl*rl)), min(round(sqrt(area*sh/rl)), image side)).
    Its position is uniform inside the image.

    Args:
        img: (H, W, C) image tensor. H, W and C are read from the tensor itself, so this
            also works on images cropped to a size other than IMAGE_SIZE.
        sl, sh: lower / upper area fractions used in the side-length bounds.
        rl: ratio factor used in the side-length bounds.
        prob: probability of erasing (0.75, i.e. the previous `p >= 0.25` rule).

    Returns:
        Tensor with the same shape and dtype as `img`.
    """

    def _erase():
        shape = tf.shape(img)
        h, w, c = shape[0], shape[1], shape[2]
        origin_area = tf.cast(h * w, tf.float32)

        # Bounds for the rectangle's side lengths.
        e_size_l = tf.cast(tf.round(tf.sqrt(origin_area * sl * rl)), tf.int32)
        e_size_h = tf.cast(tf.round(tf.sqrt(origin_area * sh / rl)), tf.int32)
        e_height_h = tf.minimum(e_size_h, h)
        e_width_h = tf.minimum(e_size_h, w)

        erase_height = tf.random.uniform(shape=[], minval=e_size_l, maxval=e_height_h, dtype=tf.int32)
        erase_width = tf.random.uniform(shape=[], minval=e_size_l, maxval=e_width_h, dtype=tf.int32)

        # Random position: split the leftover rows / columns into padding on each side.
        pad_h = h - erase_height
        pad_top = tf.random.uniform(shape=[], minval=0, maxval=pad_h, dtype=tf.int32)
        pad_w = w - erase_width
        pad_left = tf.random.uniform(shape=[], minval=0, maxval=pad_w, dtype=tf.int32)

        # Mask is 0 inside the rectangle and 1 everywhere else.
        erase_area = tf.zeros(shape=[erase_height, erase_width, c], dtype=tf.float32)
        erase_mask = tf.pad(
            erase_area,
            [[pad_top, pad_h - pad_top], [pad_left, pad_w - pad_left], [0, 0]],
            constant_values=1,
        )
        erased_img = tf.multiply(tf.cast(img, tf.float32), erase_mask)
        return tf.cast(erased_img, img.dtype)

    # Draw the apply/skip decision with a TF op. Inside Dataset.map this function is traced
    # once, so a Python random.random() call would run only at trace time and the same
    # branch would then be taken for every image.
    apply_erase = tf.random.uniform(shape=[], dtype=tf.float32) < prob
    return tf.cond(apply_erase, _erase, lambda: tf.cast(img, img.dtype))

    
def data_augment_v2(image, label):
    """Augmentation v2: one of {left-right flip, up-down flip, random crop}, then random blockout.

    Each option has probability 1/3. The crop keeps 60-80% of each side and is resized back
    to IMAGE_SIZE, so every output has the same shape and a mapped dataset can still be
    batched.

    All random choices are TF ops, so they are re-drawn per element inside Dataset.map.
    Python `random` calls would be evaluated once at trace time and give every image the
    same augmentation.

    Returns:
        (augmented image, label), with the label passed through unchanged.
    """
    choice = tf.random.uniform([], minval=0, maxval=3, dtype=tf.int32)
    pct_h = tf.random.uniform([], minval=60, maxval=81, dtype=tf.int32)  # 60..80 inclusive
    pct_w = tf.random.uniform([], minval=60, maxval=81, dtype=tf.int32)

    def _crop():
        crop_h = IMAGE_SIZE[0] * pct_h // 100
        crop_w = IMAGE_SIZE[1] * pct_w // 100
        cropped = tf.image.random_crop(image, tf.stack([crop_h, crop_w, 3]), seed=SEED)
        # Resize back so all branches return the same shape (and dtype).
        return tf.cast(tf.image.resize(cropped, IMAGE_SIZE), image.dtype)

    augmented = tf.switch_case(
        choice,
        branch_fns=[
            lambda: tf.image.random_flip_left_right(image, seed=SEED),
            lambda: tf.image.random_flip_up_down(image, seed=SEED),
            _crop,
        ],
    )
    augmented = random_blockout(augmented)

    return augmented, label

# 數據增強v3

In [8]:
import tensorflow_addons as tfa

def data_augment_v3(image, label):
    """Augmentation v3: upscale + random crop, colour jitter, mean blur, random flips.

    Every random op gets its own op-level seed. Ops traced into the same Dataset.map graph
    with one shared seed draw the same underlying uniform value. The two flips would then
    always fire together, and brightness/saturation/contrast would move in lock-step.

    Returns:
        (augmented image of shape (*IMAGE_SIZE, 3) clipped to [0, 1], label).
    """
    seed = 100

    # Upscale by 720/512 (720 at the default 512 size), then crop back to IMAGE_SIZE.
    upscaled_size = [IMAGE_SIZE[0] * 720 // 512, IMAGE_SIZE[1] * 720 // 512]
    image = tf.image.resize(image, upscaled_size)
    image = tf.image.random_crop(image, [*IMAGE_SIZE, 3], seed=seed)

    image = tf.image.random_brightness(image, 0.6, seed=seed + 1)
    # Brightness shifts can leave [0, 1]; clip before the HSV-based saturation change.
    image = tf.clip_by_value(image, 0.0, 1.0)

    image = tf.image.random_saturation(image, 3, 5, seed=seed + 2)

    image = tf.image.random_contrast(image, 0.3, 0.5, seed=seed + 3)
    image = tf.clip_by_value(image, 0.0, 1.0)

    image = tfa.image.mean_filter2d(image, filter_shape=10)

    image = tf.image.random_flip_left_right(image, seed=seed + 4)
    image = tf.image.random_flip_up_down(image, seed=seed + 5)

    return image, label

# 建立數據Pipeline

In [9]:
def data_augment(image, label):
    """Training-time augmentation: random horizontal flip only; the label passes through."""
    flipped = tf.image.random_flip_left_right(image)
    return flipped, label

def get_training_dataset():
    """Infinite, augmented, shuffled, batched and prefetched (image, label) training dataset.

    NOTE(review): reads the global `batch_size`, which is only defined in a later cell
    (BATCH_SIZE * replicas); calling this before that cell raises NameError.
    """
    return (
        load_dataset(training_filenames, labeled=True)
        .map(data_augment, num_parallel_calls=AUTO)
        .repeat()
        .shuffle(2048)
        .batch(batch_size)
        .prefetch(AUTO)
    )

def get_validation_dataset(ordered=False):
    """Batched, cached, prefetched (image, label) validation dataset without augmentation.

    Pass ordered=True when predictions must line up with labels read in a separate pass.
    NOTE(review): uses the global `batch_size`, defined in a later cell.
    """
    return (
        load_dataset(validation_filenames, labeled=True, ordered=ordered)
        .batch(batch_size)
        .cache()
        .prefetch(AUTO)
    )

def get_test_dataset(ordered=False):
    """Batched, prefetched (image, id) test dataset.

    Pass ordered=True when images and ids are read in two separate passes.
    NOTE(review): uses the global `batch_size`, defined in a later cell.
    """
    return (
        load_dataset(test_filenames, labeled=False, ordered=ordered)
        .batch(batch_size)
        .prefetch(AUTO)
    )

def count_data_items(filenames):
    """Total number of records across TFRecord shards.

    The record count is taken from the first '-<digits>.' group in each file name
    (presumably names like '00-512x512-798.tfrec').
    """
    pattern = re.compile(r"-([0-9]*)\.")
    counts = []
    for name in filenames:
        counts.append(int(pattern.search(name).group(1)))
    return np.sum(counts)

# Record counts per split, parsed from the shard file names.
num_training_images, num_validation_images, num_test_images = (
    count_data_items(files)
    for files in (training_filenames, validation_filenames, test_filenames)
)
print('{}筆train, {}筆val, {}筆未標註的test'.format(num_training_images, num_validation_images, num_test_images))


In [10]:
# Global batch size = per-replica BATCH_SIZE x number of replicas.
batch_size = BATCH_SIZE * strategy.num_replicas_in_sync

ds_train, ds_valid, ds_test = get_training_dataset(), get_validation_dataset(), get_test_dataset()

print("Train:", ds_train)
print("Val:", ds_valid)
print("Test:", ds_test)

In [11]:
# Preview: shapes of the first three training batches, then the labels of the last one.
np.set_printoptions(linewidth=80, threshold=15)  # keep printed arrays short

print("Training data shapes:")
for image, label in ds_train.take(3):
    print(*(t.numpy().shape for t in (image, label)))
print("Training data label examples:", label.numpy())

In [12]:
print("Test data shapes:")
for image, idnum in ds_test.take(3):
    print(*(t.numpy().shape for t in (image, idnum)))
# ids are byte strings; astype('U') shows them as text
print("Test data IDs:", idnum.numpy().astype('U'))

# EDA

In [13]:
# Compute class weights from the training label distribution.
from collections import Counter
import gc
gc.enable()

def get_training_dataset_raw():
    """Unbatched, un-augmented (image, label) training dataset."""
    dataset = load_dataset(training_filenames, labeled = True, ordered = False)
    return dataset


def _read_class_only(example):
    """Parse only the 'class' feature of a labeled record (no JPEG decoding)."""
    parsed = tf.io.parse_single_example(example, {"class": tf.io.FixedLenFeature([], tf.int64)})
    return parsed["class"]


# Counting only needs labels, so read the class feature directly instead of decoding
# every training image through get_training_dataset_raw().
label_counter = Counter(
    int(class_id)
    for class_id in tf.data.TFRecordDataset(training_filenames, num_parallel_reads=AUTO)
    .map(_read_class_only, num_parallel_calls=AUTO)
    .as_numpy_iterator()
)

TARGET_NUM_PER_CLASS = 122  # reference count: weight = TARGET_NUM_PER_CLASS / class count

def get_weight_for_class(class_id):
    """Return TARGET_NUM_PER_CLASS / (number of training examples of class_id).

    Returns 1.0 for a class with no training examples. Such a weight is never applied,
    and this avoids a ZeroDivisionError.
    """
    counting = label_counter[class_id]
    if counting == 0:
        return 1.0
    return TARGET_NUM_PER_CLASS / counting

weight_per_class = {class_id: get_weight_for_class(class_id) for class_id in range(len(CLASSES))}

In [14]:
data = pd.DataFrame.from_dict(weight_per_class, orient='index', columns=['class_weight'])
plt.figure(figsize=(30, 9))

# Bar colour follows the weight value (scaled into the Blues colormap).
bar_colors = cm.Blues(data['class_weight'] * 0.15)
bplot = sns.barplot(x=data.index, y='class_weight', data=data, palette=bar_colors)

# Write each bar's weight just above it.
for patch in bplot.patches:
    height = patch.get_height()
    bplot.annotate(
        format(height, '.1f'),
        (patch.get_x() + patch.get_width() / 2., height),
        ha='center',
        va='center',
        xytext=(0, 9),
        textcoords='offset points',
    )
plt.xlabel("Class", size=14)
plt.ylabel("Class weight (inverse of %)", size=14)

In [15]:
from matplotlib import pyplot as plt

def batch_to_numpy_images_and_labels(data):
    """Split an (images, labels) batch into numpy data.

    If the second element holds byte-string image ids (object dtype, as for test data)
    rather than class labels, a list of None of matching length is returned instead.
    """
    images, labels = data
    numpy_images = images.numpy()
    numpy_labels = labels.numpy()
    if numpy_labels.dtype != object:
        return numpy_images, numpy_labels
    return numpy_images, [None] * len(numpy_images)

def title_from_label_and_target(label, correct_label):
    """Build a subplot title for the predicted class `label`; returns (title, correct).

    Without ground truth (`correct_label` is None) the title is just the class name and
    `correct` is True. Otherwise the title is "<pred> [OK]" or "<pred> [NO→<true>]".
    """
    if correct_label is None:
        return CLASSES[label], True
    correct = (label == correct_label)
    if correct:
        verdict, arrow, true_name = 'OK', '', ''
    else:
        verdict, arrow, true_name = 'NO', u"\u2192", CLASSES[correct_label]
    return "{} [{}{}{}]".format(CLASSES[label], verdict, arrow, true_name), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Draw `image` at subplot position (rows, cols, index) and return the next position.

    A non-empty `title` is drawn above the image. red=True gives a red, slightly smaller
    title; callers use it for wrong predictions.
    """
    rows, cols, index = subplot
    plt.subplot(rows, cols, index)
    plt.axis('off')
    plt.imshow(image)
    if title:
        fontsize = int(titlesize / 1.2) if red else int(titlesize)
        plt.title(
            title,
            fontsize=fontsize,
            color='red' if red else 'black',
            fontdict={'verticalalignment': 'center'},
            pad=int(titlesize / 1.5),
        )
    return (rows, cols, index + 1)
    
def display_batch_of_images(databatch, predictions=None, display_mismatches_only=False):
    """Show a batch of images in a near-square grid with optional titles.

    This will work with:
    display_batch_of_images(images)
    display_batch_of_images(images, predictions)
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)

    Args:
        databatch: an image batch (tensor or array), or an (images, labels) pair. The labels
            may also be byte-string ids, in which case no label titles are shown.
        predictions: optional predicted class ids, one per image. Titles then show
            "<pred> [OK]" or "<pred> [NO→<true>]", and wrong ones are drawn in red.
        display_mismatches_only: with `predictions`, skip images whose prediction equals
            the label. Ignored when `predictions` is None.

    Only rows * cols images are drawn (rows = floor(sqrt(n)), cols = n // rows); the rest
    of the batch is dropped. An empty batch draws nothing.
    """
    # data: only (images, labels) pairs go through batch_to_numpy_images_and_labels,
    # which unpacks a pair; a bare image batch simply has no labels.
    if isinstance(databatch, (tuple, list)):
        images, labels = batch_to_numpy_images_and_labels(databatch)
    else:
        images = databatch.numpy() if hasattr(databatch, 'numpy') else np.asarray(databatch)
        labels = None
    if labels is None:
        labels = [None for _ in enumerate(images)]
    if len(images) == 0:
        return  # nothing to draw (rows would be 0 below)

    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images) // rows
    n_shown = rows * cols

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot = (rows, cols, 1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE, FIGSIZE / cols * rows))
    else:
        plt.figure(figsize=(FIGSIZE / rows * cols, FIGSIZE))
    dynamic_titlesize = FIGSIZE * SPACING / max(rows, cols) * 40 + 3  # magic formula tested to work from 1x1 to 10x10 images

    # display
    for i, (image, label) in enumerate(zip(images[:n_shown], labels[:n_shown])):
        if display_mismatches_only and predictions is not None and predictions[i] == label:
            continue
        title = '' if label is None else CLASSES[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)

    # layout: no gaps between panels when nothing has a title
    plt.tight_layout()
    if predictions is None and all(lbl is None for lbl in labels[:n_shown]):
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()


def display_training_curves(training, validation, title, subplot):
    """Plot train vs validation values of one metric into `subplot` (e.g. 211, 212).

    A new 10x10 figure is created when `subplot` is the first panel (subplot % 10 == 1).
    """
    first_panel = subplot % 10 == 1
    if first_panel:
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    for series in (training, validation):
        ax.plot(series)
    ax.set_title('model '+ title)
    ax.set_ylabel(title)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])

def display_training_curves_v2(training, validation, learning_rate_list, title, subplot):
    """Plot train vs validation values of one metric, plus the learning rate on a right-hand axis.

    A new 10x10 figure is created when `subplot` is the first panel (subplot % 10 == 1).
    """
    if subplot % 10 == 1:  # set up the subplots on the first call
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    train_line = ax.plot(training)[0]
    valid_line = ax.plot(validation)[0]
    ax.set_title('model '+ title)
    ax.set_ylabel(title, color='b')
    ax.set_xlabel('epoch')

    ax2 = ax.twinx()
    lr_line = ax2.plot(learning_rate_list, 'g-')[0]
    ax2.set_ylabel('learning rate', color='g')

    # The learning-rate line lives on ax2, so pass all three handles explicitly.
    # A label-only ax.legend([...]) sees just ax's two lines and silently drops
    # 'learning rate'. The legend is drawn on ax2 so it sits above both axes.
    ax2.legend([train_line, valid_line, lr_line], ['train', 'valid.', 'learning rate'])

In [16]:
# Re-batch the (augmented) training stream into groups of 20 for the image-grid previews.
ds_iter = iter(ds_train.unbatch().batch(20))

In [17]:
# Show the next 20 training images (after data_augment) titled with their class names.
one_batch = next(ds_iter)
display_batch_of_images(one_batch)

## preview數據增強v1

In [18]:
row, col, size = 2, 3, 10  # grid rows, grid columns, figure width in inches

# Take a single training example and apply data_augment to it row*col times.
all_elements = get_training_dataset().unbatch()
one_element = tf.data.Dataset.from_tensors(next(iter(all_elements)))
augmented_element = one_element.repeat().map(data_augment).batch(row * col)

img, label = next(iter(augmented_element))
plt.figure(figsize=(size, int(size * row / col)))
for j in range(row * col):
    plt.subplot(row, col, j + 1)
    plt.axis('off')
    plt.imshow(img[j, ])
plt.show()

## 數據增強v2

In [19]:
# Same example as above, augmented row*col times with data_augment_v2.
augmented_element = one_element.repeat().map(data_augment_v2).batch(row * col)

img, label = next(iter(augmented_element))
plt.figure(figsize=(size, int(size * row / col)))
for j in range(row * col):
    plt.subplot(row, col, j + 1)
    plt.axis('off')
    plt.imshow(img[j, ])
plt.show()

## 數據增強v3

In [20]:
# Same example as above, augmented row*col times with data_augment_v3.
augmented_element = one_element.repeat().map(data_augment_v3).batch(row * col)

img, label = next(iter(augmented_element))
plt.figure(figsize=(size, int(size * row / col)))
for j in range(row * col):
    plt.subplot(row, col, j + 1)
    plt.axis('off')
    plt.imshow(img[j, ])
plt.show()

# 建模

## 先定義模型架構

In [21]:
# List the names exposed by tf.keras.applications (candidate pretrained backbones).
print('pre-train list:\n', ', '.join(tf.keras.applications.__dir__()))

In [22]:
# EfficientNetB7 is imported from the third-party `efficientnet` package (tfkeras API),
# installed here at runtime.
# NOTE(review): the pip install is unpinned; consider pinning a version for reproducibility.
use_efficientnet = True 
if use_efficientnet:
    !pip install -q efficientnet
    from efficientnet.tfkeras import EfficientNetB7

In [23]:
# Classifier = pretrained EfficientNetB7 backbone (ImageNet weights, no top)
#              + global average pooling + softmax over CLASSES.
# Built inside strategy.scope() so its variables are created under the distribution strategy.
# Other backbones tried: tf.keras.applications.DenseNet201 / NASNetMobile / ResNet101V2 / MobileNetV2.
with strategy.scope():
    pretrained_model = EfficientNetB7(
        include_top=False,
        weights='imagenet',
        input_shape=[IMAGE_SIZE[0], IMAGE_SIZE[1], 3],
    )
    pretrained_model.trainable = True  # fine-tune the whole backbone

    model = tf.keras.Sequential(name=MODEL_NAME[0])
    model.add(pretrained_model)
    model.add(tf.keras.layers.GlobalAveragePooling2D())
    model.add(tf.keras.layers.Dense(len(CLASSES), activation='softmax'))

In [24]:
# Labels are integer class ids, so the sparse variants of the loss and accuracy are used.
model.compile(
    optimizer='nadam',
    loss = 'sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy'],
)

In [25]:
# Layer-by-layer summary with output shapes and parameter counts.
model.summary()

In [26]:
# Diagram of the model graph with layer output shapes.
tf.keras.utils.plot_model(model, show_shapes=True)

## 定義callback

In [27]:
# Keep only the weights of the epoch with the lowest validation loss.
checkpoint_filepath = f"Flowers-Classification-{model.name}.h5"

checkpoint_options = dict(
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
)
checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_filepath, **checkpoint_options)

In [28]:
# Stop after 3 epochs without val_loss improvement.
# NOTE(review): not included in the callbacks passed to model.fit — confirm whether intended.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)

In [29]:
def exponential_lr(epoch, start_lr = 0.00001, min_lr = 0.00001, max_lr = 0.00005 * strategy.num_replicas_in_sync,
                    rampup_epochs = 5, sustain_epochs = 0,
                    exp_decay = 0.75):
    """Learning rate for `epoch` (0-based): linear ramp-up, optional plateau, exponential decay.

    - epoch < rampup_epochs: linear from start_lr towards max_lr
    - the next sustain_epochs epochs: max_lr
    - afterwards: min_lr + (max_lr - min_lr) * exp_decay ** (epochs since the plateau ended)

    The default max_lr is scaled by the number of replicas (evaluated at definition time).
    """
    if epoch < rampup_epochs:
        return ((max_lr - start_lr) / rampup_epochs * epoch + start_lr)
    if epoch < rampup_epochs + sustain_epochs:
        return max_lr
    decay_epochs = epoch - rampup_epochs - sustain_epochs
    return ((max_lr - min_lr) * exp_decay**decay_epochs + min_lr)

lr_callback = tf.keras.callbacks.LearningRateScheduler(exponential_lr, verbose=True)


# Plot the schedule over the training epochs.
rng = list(range(EPOCHS))
y = [exponential_lr(epoch) for epoch in rng]
plt.plot(rng, y)
print("Learning rate curve: {:.3g} to {:.3g} to {:.3g}".format(y[0], max(y), y[-1]))

# 訓練

## Fit Model

In [30]:
# Train with the LR schedule and best-val_loss checkpointing; class weights rebalance rare classes.
history = model.fit(
    ds_train,
    validation_data=ds_valid,
    epochs=EPOCHS,
    # ds_train repeats forever, so Keras needs an explicit whole number of steps per epoch
    # ("/" gave a float here).
    steps_per_epoch=num_training_images // batch_size,
    callbacks=[lr_callback, checkpoint],
    class_weight=weight_per_class,
)

## Show result

In [31]:
def display_training_curves(training, validation, lr, title, subplot):
    """Plot train vs validation values of one metric into `subplot` (e.g. 211, 212).

    `lr` (per-epoch learning rates) is accepted so callers can pass history['lr'], but this
    version does not plot it. A new figure is created for the first panel (subplot % 10 == 1).
    """
    if subplot % 10 == 1:
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    for series in (training, validation):
        ax.plot(series)
    ax.set_title('model '+ title)
    ax.set_ylabel(title)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])

In [32]:
# Loss (top panel) and accuracy (bottom panel) per epoch, train vs validation.
for history_key, val_history_key, curve_title, curve_subplot in (
    ('loss', 'val_loss', 'loss', 211),
    ('sparse_categorical_accuracy', 'val_sparse_categorical_accuracy', 'accuracy', 212),
):
    display_training_curves(
        history.history[history_key],
        history.history[val_history_key],
        history.history['lr'],
        curve_title,
        curve_subplot,
    )

In [33]:
# Same curves, restricted to epochs from zoom_after onwards.
zoom_after = 15
for history_key, val_history_key, curve_title, curve_subplot in (
    ('loss', 'val_loss', 'loss', 211),
    ('sparse_categorical_accuracy', 'val_sparse_categorical_accuracy', 'accuracy', 212),
):
    display_training_curves(
        history.history[history_key][zoom_after:],
        history.history[val_history_key][zoom_after:],
        history.history['lr'],
        curve_title,
        curve_subplot,
    )

In [34]:
# Restore the best checkpoint (lowest val_loss) written by the ModelCheckpoint callback.
model.load_weights(checkpoint_filepath)

In [35]:
# Summary again after restoring the best weights (the architecture is unchanged).
model.summary()

In [36]:
# Optional: convert the trained Keras model to a TensorFlow Lite file next to the checkpoint.
# Disabled by default; set CONVERT_TO_TFLITE = True to run it. Real code behind a flag
# replaces bare string literals, which never ran and were echoed as the cell's output.
CONVERT_TO_TFLITE = False

if CONVERT_TO_TFLITE:
    print(checkpoint_filepath)
    tflite_model_name = checkpoint_filepath.replace('.h5', '.tflite')

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()

    # Save the model
    with open(tflite_model_name, 'wb') as f:
        f.write(tflite_model)

    print('TFLiteConversion completed successfully \U0001F680')

## ensemble

In [37]:
# Ensemble mode is enabled whenever more than one model name is configured.
using_ensemble_models = len(MODEL_NAME) > 1

In [38]:
def get_pretrained_model(model_name, image_dataset_weights, trainable=True):
    """Build backbone (include_top=False) + global average pooling + softmax over CLASSES.

    Args:
        model_name: backbone constructor, a callable such as EfficientNetB7 (not a string).
        image_dataset_weights: value for the backbone's `weights` argument.
        trainable: whether the backbone's weights are trainable.

    Returns:
        An uncompiled tf.keras.Sequential model.
    """
    backbone = model_name(
        include_top=False,
        weights=image_dataset_weights,
        input_shape=[IMAGE_SIZE[0], IMAGE_SIZE[1], 3],
    )
    backbone.trainable = trainable

    classifier_layers = [
        backbone,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax'),
    ]
    return tf.keras.Sequential(classifier_layers)

In [39]:
if using_ensemble_models:
    # Rebuild the EfficientNetB7 classifier, then load weights saved by an earlier run.
    with strategy.scope():
        model_EB7 = get_pretrained_model(EfficientNetB7, 'noisy-student', trainable=True)
    eb7_weights_path = f'../input/models/Flowers-Classification-{model.name}.h5'
    model_EB7.load_weights(eb7_weights_path)

In [40]:
# Architecture of the EfficientNetB7 ensemble member.
if using_ensemble_models:
    model_EB7.summary()

In [41]:
if using_ensemble_models:
    # Rebuild the DenseNet201 classifier, then load weights saved by an earlier run.
    with strategy.scope():
        model_D201 = get_pretrained_model(tf.keras.applications.DenseNet201, 'imagenet', trainable=True)
    d201_weights_path = '../input/models/Flowers-Classification-DenseNet201.h5'
    model_D201.load_weights(d201_weights_path)

In [42]:
# Architecture of the DenseNet201 ensemble member.
if using_ensemble_models:
    model_D201.summary()

## Ensemble both models

In [43]:
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

In [44]:
if using_ensemble_models:
    # Ordered: images and labels are read in two separate passes and must line up.
    cmdataset = get_validation_dataset(ordered=True)
    images_ds = cmdataset.map(lambda image, label: image)
    labels_ds = cmdataset.map(lambda image, label: label).unbatch()
    cm_correct_labels = next(iter(labels_ds.batch(num_validation_images))).numpy() # get everything as one batch

    m1 = model_EB7.predict(images_ds)
    m2 = model_D201.predict(images_ds)

    # Search the blend weight alpha (share of EB7) on validation macro-F1.
    alphas = np.linspace(0, 1, 100)
    scores = []
    for alpha in alphas:
        cm_probabilities = alpha*m1+(1-alpha)*m2
        cm_predictions = np.argmax(cm_probabilities, axis=-1)
        scores.append(f1_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro'))
    plt.plot(alphas, scores)

    # alphas[i] == i / 99, so look the best value up in the grid; argmax/100 gave the wrong alpha.
    best_alpha = alphas[np.argmax(scores)]
    cm_probabilities = best_alpha*m1+(1-best_alpha)*m2
    cm_predictions = np.argmax(cm_probabilities, axis=-1)

    # Printed after choosing best_alpha so these are the predictions actually used.
    print("Correct   labels: ", cm_correct_labels.shape, cm_correct_labels)
    print("Predicted labels: ", cm_predictions.shape, cm_predictions)

In [45]:
# Chosen blend weight and the best validation macro-F1 it achieved.
if using_ensemble_models:
    print(best_alpha, max(scores))

In [46]:
if using_ensemble_models:
    test_ds = get_test_dataset(ordered=True)
    #best_alpha = 0.35

    print('Computing predictions...')
    test_images_ds = test_ds.map(lambda image, idnum: image)
    probabilities1 = model_EB7.predict(test_images_ds)
    probabilities2 = model_D201.predict(test_images_ds)

    probabilities = best_alpha * probabilities1 + (1 - best_alpha) * probabilities2

    predictions = np.argmax(probabilities, axis=-1)
    print(predictions)

    print('Generating submission.csv file...')
    # Get image ids from test set and convert to unicode
    test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
    test_ids = next(iter(test_ids_ds.batch(num_test_images))).numpy().astype('U')

    # Write the submission file
    np.savetxt(
        'submission.csv',
        np.rec.fromarrays([test_ids, predictions]),
        fmt=['%s', '%d'],
        delimiter=',',
        header='id,label',
        comments='',
    )

    # Look at the first few predictions
    !head submission.csv

# 評估模型

In [47]:
def display_confusion_matrix(cmat, score, precision, recall):
    """Draw `cmat` as a heat map labelled with CLASSES, plus optional metric text.

    Args:
        cmat: square matrix whose rows/columns follow the order of CLASSES.
        score, precision, recall: optional scalars (None to omit). They are written in the
            top-right corner and echoed to stdout.

    When no ensemble is used, it also prints the epochs (and values) of minimum val_loss
    and maximum validation accuracy from the global `history`.
    """
    plt.figure(figsize=(25,25))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Reds')
    ax.set_xticks(range(len(CLASSES)))
    ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    ax.set_yticks(range(len(CLASSES)))
    ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    titlestring = ""
    if score is not None:
        titlestring += 'f1 = {:.3f} '.format(score)
    if precision is not None:
        titlestring += '\nprecision = {:.3f} '.format(precision)
    if recall is not None:
        titlestring += '\nrecall = {:.3f} '.format(recall)
    if len(titlestring) > 0:
        # x is in matrix (column) coordinates: len(CLASSES) - 3 (= 101 for 104 classes)
        # keeps the right-aligned text inside the matrix for any number of classes.
        ax.text(len(CLASSES) - 3, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})

    if not using_ensemble_models:
        print('Epoch with min loss and max accuracy:', np.argmin(history.history['val_loss']), np.argmax(history.history['val_sparse_categorical_accuracy']))
        print('min loss and max accuracy:', round(min(history.history['val_loss']),2), round(max(history.history['val_sparse_categorical_accuracy']),2))

    print(titlestring.replace('\n', ''))
    plt.show()
    
def display_training_curves(training, validation, title, subplot):
    """Plot train vs validation values of one metric into `subplot` (e.g. 211, 212).

    Starts a new 10x10 figure when `subplot` is the first panel (subplot % 10 == 1).
    """
    is_first_panel = (subplot % 10 == 1)
    if is_first_panel:
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    ax.plot(training)
    ax.plot(validation)
    ax.set_title('model '+ title)
    ax.set_ylabel(title)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])

In [48]:
# Confusion matrix on the validation set (ordered, so images and labels line up).
cmdataset = get_validation_dataset(ordered=True)
images_ds = cmdataset.map(lambda image, label: image)
labels_ds = cmdataset.map(lambda image, label: label).unbatch()

cm_correct_labels = next(iter(labels_ds.batch(num_validation_images))).numpy()

if using_ensemble_models:
    print('using_ensemble_models')
    probabilities1 = model_EB7.predict(images_ds)
    probabilities2 = model_D201.predict(images_ds)
    cm_probabilities = best_alpha * probabilities1 + (1 - best_alpha) * probabilities2
else:
    cm_probabilities = model.predict(images_ds)

cm_predictions = np.argmax(cm_probabilities, axis=-1)

labels = range(len(CLASSES))
cmat = confusion_matrix(
    cm_correct_labels,
    cm_predictions,
    labels=labels,
)
# Row-normalise (each row = one true class). A class with no validation samples would
# divide by zero and turn its row into NaN; leave such rows at 0 instead.
row_totals = cmat.sum(axis=1, keepdims=True)
cmat = np.divide(cmat, row_totals, out=np.zeros(cmat.shape, dtype=float), where=row_totals != 0)

In [49]:
# Show the row-normalised confusion matrix (rows: true class, columns: predicted class).
cmat

In [50]:
# Macro-averaged metrics over all classes, then the confusion-matrix figure.
metric_kwargs = dict(labels=labels, average='macro')
score = f1_score(cm_correct_labels, cm_predictions, **metric_kwargs)
precision = precision_score(cm_correct_labels, cm_predictions, **metric_kwargs)
recall = recall_score(cm_correct_labels, cm_predictions, **metric_kwargs)

display_confusion_matrix(cmat, score, precision, recall)

# 視覺化驗證

In [51]:
# Validation stream re-batched into groups of 20 for visual inspection.
dataset = get_validation_dataset().unbatch().batch(20)
batch = iter(dataset)

In [52]:
# First batch of 20 validation images with their labels.
images, labels = next(batch)

In [53]:
# Class probabilities for this batch: the single trained model or the blended ensemble.
if not using_ensemble_models:
    probabilities = model.predict(images)
else:
    probabilities1 = model_EB7.predict(images)
    probabilities2 = model_D201.predict(images)
    probabilities = best_alpha * probabilities1 + (1 - best_alpha) * probabilities2

In [54]:
# Predicted class = argmax over class probabilities; titles compare it with the true label.
predictions = np.argmax(probabilities, axis=-1)
display_batch_of_images((images, labels), predictions)

## 檢查 validation set 上的錯誤

In [55]:
# Number and share of misclassified validation images.
mismatches = np.count_nonzero(cm_predictions != cm_correct_labels)
print('validation data上的錯誤: {} / {} ({:.2%})'.format(mismatches, num_validation_images, mismatches / num_validation_images))

In [56]:
# Re-read the ordered validation labels and list every misclassified example.
cmdataset = get_validation_dataset(ordered=True)
images_ds = cmdataset.map(lambda image, label: image)
labels_ds = cmdataset.map(lambda image, label: label).unbatch()
cm_correct_labels = next(iter(labels_ds.batch(num_validation_images))).numpy() # get everything as one batch

# Mismatch positions come straight from the prediction/label arrays; there is no need to
# pull every validation image through the input pipeline one at a time.
mismatch_indices = np.flatnonzero(cm_predictions != cm_correct_labels)
for image_index in mismatch_indices:
    print('Predicted vs Correct labels: {}, {}'.format(cm_predictions[image_index], cm_correct_labels[image_index]))

In [57]:
# Fresh validation stream in batches of 20, positioned at its first batch.
dataset = get_validation_dataset().unbatch().batch(20)
batch = iter(dataset)
images, labels = next(batch)

In [58]:
# Show only the misclassified images from three consecutive validation batches of 20.
# Predictions are computed for each batch here. The earlier `predictions` belong to a
# different batch, drawn from a differently built, unordered dataset, so reusing them
# compares the wrong images.
for i in range(3):
    if using_ensemble_models:
        batch_probabilities = best_alpha * model_EB7.predict(images) + (1 - best_alpha) * model_D201.predict(images)
    else:
        batch_probabilities = model.predict(images)
    batch_predictions = np.argmax(batch_probabilities, axis=-1)
    display_batch_of_images((images, labels), batch_predictions, display_mismatches_only=True)
    images, labels = next(batch)

In [59]:
# Another 20 training images from the shared ds_iter iterator.
one_batch = next(ds_iter)
display_batch_of_images(one_batch)

# 預測test set

## Test Time Augmentation (TTA) 

In [60]:
using_tta = False  # True: average predictions over several augmented passes of the test set
tta_iterations = 3  # number of augmented passes when using_tta is True

In [61]:
# With TTA on, rebind get_test_dataset so every test pass applies data_augment
# (random horizontal flip). Alternatives: data_augment_v2 / data_augment_v3.
if using_tta:
    def get_test_dataset(ordered=False):
        """Batched, prefetched (augmented image, id) test dataset."""
        return (
            load_dataset(test_filenames, labeled=False, ordered=ordered)
            .map(data_augment, num_parallel_calls=AUTO)
            .batch(batch_size)
            .prefetch(AUTO)
        )

In [62]:
def predict_tta(model, tta_iterations):
    """Run `tta_iterations` prediction passes over the test set; return the list of probability arrays.

    Each pass rebuilds the dataset via get_test_dataset(ordered=True), so any augmentation
    inside it is re-drawn. With ensembles enabled, each pass uses the EB7/D201 blend
    (global best_alpha) and `model` is not used.
    """
    probs = []
    for i in range(tta_iterations):
        print('TTA iteration ', i)
        # Ordered: images and ids are read in separate passes and must stay aligned.
        test_ds = get_test_dataset(ordered=True)
        test_images_ds = test_ds.map(lambda image, idnum: image)

        if not using_ensemble_models:
            probs.append(model.predict(test_images_ds,verbose=0))
            continue

        print('using_ensemble_models')
        blended = (best_alpha * model_EB7.predict(test_images_ds)
                   + (1 - best_alpha) * model_D201.predict(test_images_ds))
        probs.append(blended)

    return probs

In [63]:
# Final test predictions (ordered, so they line up with the ids read for the submission).
test_ds = get_test_dataset(ordered=True)
test_images_ds = test_ds.map(lambda image, idnum: image)

if using_tta:
    print('Predictions using TTA...')
    probabilities = np.mean(predict_tta(model, tta_iterations), axis=0)
elif using_ensemble_models:
    # Use the same alpha-blended ensemble as in validation. Predicting with `model` alone
    # here would overwrite the ensemble submission with single-model output.
    print('Predictions...')
    probabilities1 = model_EB7.predict(test_images_ds)
    probabilities2 = model_D201.predict(test_images_ds)
    probabilities = best_alpha * probabilities1 + (1 - best_alpha) * probabilities2
else:
    print('Predictions...')
    probabilities = model.predict(test_images_ds)
predictions = np.argmax(probabilities, axis=-1)
print(predictions)

In [64]:
print('using_ensemble_models:', using_ensemble_models)
print('Generating submission.csv file...')

# Get image ids from test set and convert to unicode
test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
test_ids = next(iter(test_ids_ds.batch(num_test_images))).numpy().astype('U')

# Write the submission file
np.savetxt(
    'submission.csv',
    np.rec.fromarrays([test_ids, predictions]),
    fmt=['%s', '%d'],
    delimiter=',',
    header='id,label',
    comments='',
)

# Look at the first few predictions
!head submission.csv