In [None]:
# Reference implementation of Sharpness-Aware Minimization (SAM) in TensorFlow: https://github.com/sayakpaul/Sharpness-Aware-Minimization-TensorFlow

In [None]:
!pip install tensorflow_addons
!pip install focal_loss

In [None]:
# 因使用TPU模型必須需放在gcp storage上，這步驟需要給google colab存取gcp storage的權限
from google.colab import auth
import os
auth.authenticate_user()
project_id = 'intrepid-vista-285204' # 需依照自己的project name命名
!gcloud config set project {project_id}

In [None]:
from google.colab import drive

# Mount Google Drive and work from the project folder, so relative paths such as
# './service/...' resolve inside it.
drive.mount(mountpoint='/content/drive')
os.chdir('/content/drive/MyDrive/T-Brain')

In [None]:
import os
import PIL
import PIL.Image
import pickle
import os
import re

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.callbacks import LearningRateScheduler
from sklearn.utils import class_weight
from focal_loss import SparseCategoricalFocalLoss

import autoaugment

# Let TensorFlow fall back to another device when an op has no kernel for the requested one.
tf.config.set_soft_device_placement(enabled=True)

In [None]:
# Use the Colab TPU when one is attached; otherwise fall back to MirroredStrategy
# (all local GPUs, or the CPU).
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    # The TPU system must be initialized before any TPU computation is built.
    tf.tpu.experimental.initialize_tpu_system(resolver)
    tpu_devices = tf.config.list_logical_devices('TPU')
    print("All devices: ", tpu_devices)
    strategy = tf.distribute.TPUStrategy(resolver)
except ValueError:
    print("Not connected to a TPU runtime. Using CPU/GPU strategy")
    strategy = tf.distribute.MirroredStrategy()

TFRecord datasets

In [None]:
# Load Dataset
# Each dataset folder on GCS holds label_cnt_dict.json plus train/ and val/ TFRecord shards.
data_set_name = ["test_cv2", "test_ori"]

label_info_path, train_path, val_path = (
    [f'gs://esun--2021/tf_records/{name}/{sub_dir}' for name in data_set_name]
    for sub_dir in ('', 'train/*', 'val/*')
)

In [None]:
# Show the label-info folders and the train/val shard globs, one list per line.
print(label_info_path, train_path, val_path, sep='\n')

In [None]:
# 
def read_record(example, image_size=224):
    """Parse one serialized tf.train.Example into an (image, label) pair.

    Args:
        example: scalar string tensor holding a serialized tf.train.Example with
            an "image" feature (JPEG bytes) and a "label" feature (int64).
        image_size: expected height and width of every decoded JPEG.

    Returns:
        image: decoded RGB tensor of shape [image_size, image_size, 3].
        label: int32 scalar tensor.
    """
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example, features)

    image = tf.image.decode_jpeg(example["image"], channels=3)
    # ensure_shape (unlike reshape) raises on an image whose height/width differ but
    # whose pixel count happens to match, instead of silently scrambling its pixels.
    # It also gives the tensor the static shape needed for batching on TPU.
    image = tf.ensure_shape(image, [image_size, image_size, 3])
    label = tf.cast(example["label"], tf.int32)

    return image, label

In [None]:
AUTO = tf.data.experimental.AUTOTUNE

def prepare_dataset(file_path, order=False):
    """Build a dataset of decoded (image, label) pairs from TFRecord shards.

    Args:
        file_path: a glob pattern (str) or an iterable of glob patterns
            (e.g. 'gs://bucket/name/train/*'); matches are read in glob order.
        order: if True, tf.data keeps a deterministic element order; if False
            (default), elements may be produced out of order for throughput.

    Returns:
        A tf.data.Dataset of (image, label) pairs parsed by `read_record`.

    Raises:
        FileNotFoundError: if no file matches any of the patterns; otherwise the
            pipeline would silently be empty.
    """
    patterns = [file_path] if isinstance(file_path, str) else list(file_path)
    filenames = [name for pattern in patterns for name in tf.io.gfile.glob(pattern)]
    if not filenames:
        raise FileNotFoundError(f"No TFRecord files match {patterns}")
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)

    # With order=False, tf.data may return elements as soon as they are ready.
    options = tf.data.Options()
    options.experimental_deterministic = bool(order)
    dataset = dataset.with_options(options)

    return dataset.map(read_record, num_parallel_calls=AUTO)

# Unbatched (image, label) pipelines for both splits, using the default non-deterministic order.
train_dataset, val_dataset = (prepare_dataset(paths) for paths in (train_path, val_path))

In [None]:
# Global batch size: 64 examples per replica, scaled by the number of replicas (TPU cores).
batch_size = 64 * strategy.num_replicas_in_sync

def auto_aug(image, label):
    """Augment one image with the project's autoaugment.distort_image; the label passes through.

    Uses aug_name='ra_aa' with one layer at magnitude 5.
    """
    image = autoaugment.distort_image(image, aug_name='ra_aa', ra_num_layers=1, ra_magnitude=5)
    return image, label

# NOTE: this cell reassigns train_dataset/val_dataset, so re-run the dataset-building
# cell before re-running this one.
train_dataset = (
    train_dataset
    .map(auto_aug, num_parallel_calls=AUTO)
    .shuffle(16384)
    .repeat()
    # TPUs need static shapes, hence drop_remainder=True.
    .batch(batch_size, drop_remainder=True)
    # Let tf.data size the prefetch buffer. A fixed 256 *global batches* of
    # 224x224x3 images can occupy a very large amount of host memory at TPU batch sizes.
    .prefetch(AUTO)
)
val_dataset = val_dataset.batch(batch_size, drop_remainder=True).prefetch(AUTO)

In [None]:
# Preview the first images of one augmented training batch, titled with their labels.
image, label = next(iter(train_dataset))

fig, axes = plt.subplots(constrained_layout=True, nrows=3, ncols=3, figsize=(10, 10))
fig.suptitle("Training batch (after auto_aug)")
for idx, ax in enumerate(axes.flat):
    ax.axis("off")
    # Leave grid cells empty if the batch holds fewer than 9 images.
    if idx < len(image):
        ax.imshow(image[idx], aspect="auto")
        ax.set_title(f"label {int(label[idx])}")
plt.show()

In [None]:
# Preview the first images of one (non-augmented) validation batch, titled with their labels.
image, label = next(iter(val_dataset))

fig, axes = plt.subplots(constrained_layout=True, nrows=3, ncols=3, figsize=(10, 10))
fig.suptitle("Validation batch")
for idx, ax in enumerate(axes.flat):
    ax.axis("off")
    # Leave grid cells empty if the batch holds fewer than 9 images.
    if idx < len(image):
        ax.imshow(image[idx], aspect="auto")
        ax.set_title(f"label {int(label[idx])}")
plt.show()

In [None]:
class SAMModel(tf.keras.Model):
    """Keras wrapper that trains `model` with Sharpness-Aware Minimization (SAM).

    Each training step:
      1. computes the gradient g of the loss at the current weights w;
      2. moves the weights to w + e_w, where e_w = rho * g / ||g||, and computes
         the gradient there (the "SAM gradient");
      3. restores w and lets the optimizer apply the SAM gradient at w.
    """

    def __init__(self, model, rho=0.05):
        """
        Args:
            model: the wrapped Keras model whose trainable variables are optimized.
            rho: L2 radius of the weight perturbation used in step 2.
        """
        super(SAMModel, self).__init__()
        self.model = model
        self.rho = rho

    def train_step(self, data):
        """Run one SAM update on a batch of (images, labels) and return metric results."""
        (images, labels) = data

        # Pass 1: gradient at the current weights. training=True so BatchNormalization
        # and Dropout behave as in training (they default to inference mode otherwise).
        with tf.GradientTape() as tape:
            predictions = self.model(images, training=True)
            loss = self.compiled_loss(labels, predictions)
        trainable_params = self.model.trainable_variables
        gradients = tape.gradient(loss, trainable_params)
        grad_norm = self._grad_norm(gradients)
        scale = self.rho / (grad_norm + 1e-12)

        # Perturb the weights BEFORE the second forward pass. If the perturbation came
        # after it, the second tape would differentiate at the unperturbed weights and
        # SAM would degenerate into an ordinary gradient step.
        e_ws = []
        for (grad, param) in zip(gradients, trainable_params):
            e_w = grad * scale
            param.assign_add(e_w)
            e_ws.append(e_w)

        # Pass 2: gradient at the perturbed weights w + e_w. NOTE: compiled_loss is
        # called for both passes, so the reported training "loss" presumably averages
        # the clean and perturbed losses; BatchNorm moving statistics are also updated here.
        with tf.GradientTape() as tape:
            perturbed_predictions = self.model(images, training=True)
            perturbed_loss = self.compiled_loss(labels, perturbed_predictions)
        sam_gradients = tape.gradient(perturbed_loss, trainable_params)

        # Restore the original weights, then apply the SAM gradient at them.
        for (param, e_w) in zip(trainable_params, e_ws):
            param.assign_sub(e_w)

        self.optimizer.apply_gradients(zip(sam_gradients, trainable_params))

        # Metrics describe the model at the (unperturbed) weights being trained.
        self.compiled_metrics.update_state(labels, predictions)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate one batch in inference mode and update loss and metrics."""
        (images, labels) = data
        predictions = self.model(images, training=False)
        # Called for its side effect of updating the tracked (validation) loss.
        self.compiled_loss(labels, predictions)
        self.compiled_metrics.update_state(labels, predictions)
        return {m.name: m.result() for m in self.metrics}

    def _grad_norm(self, gradients):
        """Global L2 norm over all non-None gradients (norm of the per-tensor norms)."""
        norm = tf.norm(
            tf.stack([
                tf.norm(grad) for grad in gradients if grad is not None
            ])
        )
        return norm

    def call(self, x, training=None):
        """Delegate inference/training calls to the wrapped model."""
        return self.model(x, training=training)

In [None]:
def get_model(input_shape=(224, 224, 3), num_classes=801):
    """Build an EfficientNetB0 classifier with randomly initialized weights (weights=None).

    Args:
        input_shape: shape of one input image, (height, width, channels).
        num_classes: number of output classes of the top (classification) layer.

    Returns:
        A functional tf.keras.Model from `inputs` to the EfficientNetB0 outputs.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    efficient_model = EfficientNetB0(include_top=True, weights=None, classes=num_classes, input_tensor=inputs)
    model = tf.keras.Model(inputs, efficient_model.outputs)

    return model

# Create and compile the model inside the strategy scope so its variables are
# created for the selected TPU/GPU/CPU strategy.
with strategy.scope():
    model = SAMModel(get_model())
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss=SparseCategoricalFocalLoss(gamma=2), metrics=["sparse_categorical_accuracy"])

In [None]:
# Learning rate scheduler
def decay(inp, *, lr_init=0.00005, lr_max_per_replica=0.000125, lin_lr=5,
          decay_rate=0.1, scale=0.1, num_replicas=None):
    """Per-epoch learning rate: linear warm-up followed by exponential decay.

    For epochs 0..lin_lr, the rate rises linearly from lr_init to
    lr_max = lr_max_per_replica * num_replicas. After that it is
    lr_max * exp(-decay_rate * (inp - lin_lr)). The result is finally multiplied
    by `scale`, so the effective peak rate is scale * lr_max.

    The tuning knobs are keyword-only. Keras' LearningRateScheduler first tries
    schedule(epoch, lr) and, on TypeError, falls back to schedule(epoch); this way
    the current lr is never bound to one of these knobs.

    Args:
        inp: epoch index (0-based).
        num_replicas: replica count used to scale the peak rate; defaults to
            strategy.num_replicas_in_sync.

    Returns:
        The learning rate as a Python float.
    """
    if num_replicas is None:
        num_replicas = strategy.num_replicas_in_sync
    # Max learning rate is scaled with the number of replicas (TPU cores).
    lr_max = lr_max_per_replica * num_replicas
    # lin_lr > 0 guards the division; with no warm-up, decay starts from lr_max at epoch 0.
    if lin_lr > 0 and inp <= lin_lr:
        lr = inp * (lr_max - lr_init) / lin_lr + lr_init
    else:
        lr = lr_max * np.exp(-decay_rate * (inp - lin_lr))
    return float(lr * scale)

# Keras calls decay(epoch) at the start of every epoch to set the optimizer's learning rate.
lrs = LearningRateScheduler(schedule=decay, verbose=0)

In [None]:
import json
import os

# Merge the per-dataset label -> image-count maps into one dict keyed by int label.
label_cnt_dict = {}
for file_path in label_info_path:
    raw = tf.io.read_file(filename=os.path.join(file_path, 'label_cnt_dict.json'))

    # Decode the bytes as UTF-8. Slicing the Python repr (str(b'...')[2:-1]) leaves
    # escape sequences such as a literal "\n" in the text and breaks json.loads on
    # pretty-printed or non-ASCII files.
    tmp_dict = json.loads(raw.numpy().decode('utf-8'))
    print(file_path)
    print(len(tmp_dict))
    for k, v in tmp_dict.items():
        label = int(k)
        label_cnt_dict[label] = label_cnt_dict.get(label, 0) + v
print(label_cnt_dict)

In [None]:

# Expand the per-class counts into one label per image. len(label_cnt_lst) is also
# used later as the number of training images.
label_cnt_lst = []
for k, v in label_cnt_dict.items():
    label_cnt_lst.extend([k] * v)

# Import under a local alias: another cell rebinds the name `class_weight` to a dict,
# which would break re-running this cell via the module imported at the top.
from sklearn.utils import class_weight as sk_class_weight

# `classes` and `y` are keyword-only in scikit-learn >= 1.0 (positional use errors).
class_weights = sk_class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(label_cnt_lst),
    y=label_cnt_lst,
)


In [None]:
# Key each weight by its actual label value rather than its position. np.unique returns
# the same sorted classes that compute_class_weight received, so this stays correct
# even if some label ids are missing.
# NOTE: this rebinds `class_weight`, shadowing the sklearn.utils.class_weight module
# imported at the top of the notebook.
# NOTE(review): this dict is not passed to model.fit(class_weight=...) — confirm intent.
class_weight = {
    int(label): float(weight)
    for label, weight in zip(np.unique(label_cnt_lst), class_weights)
}


In [None]:
# Total number of labelled images across all datasets.
n_labelled_images = len(label_cnt_lst)
print(n_labelled_images)

In [None]:
# Number of classes that received a weight (displayed as the cell output).
num_weighted_classes = len(class_weight)
num_weighted_classes

In [None]:
# Stop after 5 epochs without val_loss improvement and roll back to the best weights.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, verbose=0, mode='auto', restore_best_weights=True)

# Keep the best (lowest val_loss) weights on disk. Make sure the target folder exists
# before the first checkpoint is written.
weights_path = './service/fine_tuned_model/efficient_SAM_weights.h5'
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
ckpt = tf.keras.callbacks.ModelCheckpoint(weights_path, monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=True, mode='auto')

# train_dataset repeats indefinitely, so an "epoch" is one pass over the counted images
# (at least one step). Assumes label_cnt_dict.json counts the training split only — confirm.
# NOTE(review): the computed `class_weight` dict is not passed to fit(); confirm that
# relying on the focal loss alone is intended.
history = model.fit(train_dataset,
                    validation_data=val_dataset,
                    steps_per_epoch=max(1, len(label_cnt_lst) // batch_size),
                    epochs=10,
                    callbacks=[lrs, ckpt, early_stop])


In [None]:
# Build the wrapper's variables with one dummy 224x224 RGB input, then load the best checkpoint.
model(tf.zeros(shape=(1, 224, 224, 3)))
model.load_weights(filepath='./service/fine_tuned_model/efficient_SAM_weights.h5')
# Export only the wrapped EfficientNet model; SAMModel only customizes training.
inner_model = model.model
inner_model.save(filepath='./service/fine_tuned_model/efficient_SAM.hdf5')