
ADD vit #84

Open
kyakuno opened this issue May 29, 2024 · 6 comments

kyakuno commented May 29, 2024

Convert the following ViT to tflite:
https://github.com/taki0112/vit-tensorflow

@Kitazume-Ax

Working branch:
https://github.com/axinc-ai/ailia-models-tflite/tree/kitazume/add_vision_transformer
Currently runs using TensorFlow Lite (--tflite) with the float model (--float).

@Kitazume-Ax

Created the following PR:
#85


Kitazume-Ax commented Jun 25, 2024

ViT TensorFlow implementation used for the conversion

git clone the following repository:
Vision Transformer in TensorFlow 2.x
https://github.com/hrithickcodes/vision_transformer_tf

Installation using the requirements.txt described in the README fails, so install package versions that are known to work:

pip install tensorflow==2.6.5
pip install matplotlib==3.6.3
pip install contourpy==1.1.1
pip install numpy==1.19.5
pip install pyyaml

So that the graph does not contain tf.Erf, which cannot be converted to TensorFlow Lite, set gelu to approximate=True.
vision_transformer_tf/layers/pwffn.py, line 15:

self.gelu = tf.keras.layers.Lambda(lambda x: tf.keras.activations.gelu(x, approximate=True))
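As a sanity check that approximate=True is numerically safe, the tanh-based GELU can be compared against the exact erf-based form using only the standard library (an illustrative sketch, not code from the repository):

```python
import math

def gelu_exact(x: float) -> float:
    # erf-based GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # tanh approximation used when approximate=True
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

# the two forms agree to roughly 1e-3 over typical activation ranges
for v in [-3.0, -1.0, 0.0, 0.5, 1.0, 3.0]:
    print(f"{v:+.1f}: exact={gelu_exact(v):+.6f} approx={gelu_tanh(v):+.6f}")
```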

Change vit_architectures.yaml to 224x224:

ViT-BASE16: 
    encoder_layers: 12
    patch_embedding_dim: 768
    units_in_mlp: 3072
    attention_heads: 12
    image_size: [224, 224, 3]
    patch_size: 16
    dropout_rate: 0.1
    classes: 1000
    class_activation: "sigmoid"
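From this config the transformer's sequence length follows directly: a 224x224 image cut into 16x16 patches gives 14x14 = 196 patches, plus one class token (assuming the standard ViT class token) for 197 tokens of dimension 768. A quick sketch:

```python
# values from the ViT-BASE16 entry in vit_architectures.yaml above
image_size = 224
patch_size = 16
patch_embedding_dim = 768

patches_per_side = image_size // patch_size   # 14
num_patches = patches_per_side ** 2           # 196
seq_len = num_patches + 1                     # +1 class token -> 197

print(num_patches, seq_len, patch_embedding_dim)  # 196 197 768
```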

@Kitazume-Ax

Training on tf_flowers

The ImageNet pretrained weights from the original implementation are lost (the download URL returns 404).
Wrote code along the following lines and ran training:

import os
import numpy as np
from vit import viT
import tensorflow as tf
import tensorflow_datasets as tfds
from utils.loss import vit_loss
from utils.plots import plot_accuracy, plot_loss

image_size = 224
batch_size = 64
auto = tf.data.AUTOTUNE
resize_bigger = 256
num_classes = 5

learning_rate = 0.001
momentum = 0.9
global_clipnorm = 1.0
vit_config = "vit_architectures.yaml"
epochs = 30
validation_batch_size = 16

def preprocess_dataset(is_training=True):
    def _pp(image, label):
        if is_training:
            image = tf.image.resize(image, (resize_bigger, resize_bigger))
            image = tf.image.random_crop(image, (image_size, image_size, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (image_size, image_size))
        image = image / 127.5 - 1.0
        return image, label

    return _pp

def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)
    return dataset.batch(batch_size).prefetch(auto)

train_dataset, val_dataset = tfds.load("tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True)
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)

vit = viT(vit_size="ViT-BASE16", num_classes=num_classes, config_path=vit_config)

optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=momentum, global_clipnorm=global_clipnorm)

checkpoint = tf.keras.callbacks.ModelCheckpoint(os.path.join("training_weights", "ViT-BASE16_tf_flowers"),
    monitor="val_acc", save_best_only=True, save_weights_only=True)

vit.compile(optimizer=optimizer, loss=vit_loss, metrics=["acc"])

history = vit.fit(train_dataset,
                  validation_data=val_dataset,
                  shuffle=True,
                  validation_batch_size=validation_batch_size,
                  callbacks=[checkpoint],
                  epochs=epochs)
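The optimizer above uses global_clipnorm=1.0, which rescales all gradients together when their combined L2 norm exceeds the threshold. A minimal pure-Python sketch of that behavior (illustrative, not Keras internals):

```python
import math

def clip_by_global_norm(grads, clip_norm):
    # global L2 norm across every gradient element
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= clip_norm:
        return list(grads)
    # rescale so the combined norm equals clip_norm
    scale = clip_norm / global_norm
    return [g * scale for g in grads]

clipped = clip_by_global_norm([3.0, 4.0], clip_norm=1.0)  # norm 5 -> rescaled to norm 1
print(clipped)
```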


Kitazume-Ax commented Jun 25, 2024

tflite conversion

Conversion to tflite from the saved checkpoint can be run as follows:

import os
from vit import viT
import tensorflow as tf 
from utils.general import load_config

num_classes = 5
VIT_CONFIG = load_config("vit_architectures.yaml")

model = viT("ViT-BASE16", num_classes)
model.load_weights(os.path.join("training_weights", "ViT-BASE16_tf_flowers")).expect_partial()
model.compute_output_shape(input_shape = [1] + VIT_CONFIG["ViT-BASE16"]["image_size"])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
vit_tflite = converter.convert()
with open("vision_transformer_float.tflite", "wb") as f:
    f.write(vit_tflite)
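When feeding images to the converted float model, inputs must match the training-time preprocessing above: resize to 224x224, then scale pixels from [0, 255] to [-1, 1]. The scaling step as a tiny sketch:

```python
def normalize_pixel(p: float) -> float:
    # maps 0..255 -> -1.0..1.0, matching the training-time `image / 127.5 - 1.0`
    return p / 127.5 - 1.0

print(normalize_pixel(0), normalize_pixel(127.5), normalize_pixel(255))  # -1.0 0.0 1.0
```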

@Kitazume-Ax

int8 quantized tflite conversion

Run quantization using 100 images from the tf_flowers data used for training as the representative dataset.
(Limited to 100 images due to execution time on the PC.)

import os
from vit import viT
import tensorflow as tf 
from utils.general import load_config
import tensorflow_datasets as tfds

num_classes = 5
VIT_CONFIG = load_config("vit_architectures.yaml")

model = viT("ViT-BASE16", num_classes)
model.load_weights(os.path.join("training_weights", "ViT-BASE16_tf_flowers")).expect_partial()
model.compute_output_shape(input_shape = [1] + VIT_CONFIG["ViT-BASE16"]["image_size"])

model.summary()
print(os.linesep)

ds = tfds.load("tf_flowers", as_supervised=True)
ds_train = ds["train"]
print(ds_train.cardinality().numpy())

image_size = 224
def preprocess_dataset():
    def _pp(image, label):
        image = tf.image.resize(image, (image_size, image_size))
        image = image / 127.5 - 1.0
        return image, label

    return _pp

def prepare_dataset(dataset):
    return dataset.map(preprocess_dataset(), num_parallel_calls=tf.data.AUTOTUNE)

pp_ds = prepare_dataset(ds_train)

def representative_data_gen():
    for data in pp_ds.batch(1).take(100):
        yield [data[0]]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
vit_tflite = converter.convert()

# write out the quantized model (output filename here is illustrative)
with open("vision_transformer_quant.tflite", "wb") as f:
    f.write(vit_tflite)
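With inference_input_type/inference_output_type set to tf.uint8, the converter attaches an affine quantization (scale, zero_point) to each tensor, and callers map floats to uint8 as q = round(x / scale) + zero_point. A hedged sketch with illustrative parameters only; the real scale and zero_point must be read from the converted model's input/output details:

```python
def quantize(x: float, scale: float, zero_point: int) -> int:
    # affine quantization: clamp(round(x / scale) + zero_point, 0, 255)
    q = round(x / scale) + zero_point
    return max(0, min(255, q))

def dequantize(q: int, scale: float, zero_point: int) -> float:
    # inverse mapping back to float
    return (q - zero_point) * scale

# illustrative parameters for an input range of roughly [-1, 1]
scale, zero_point = 1.0 / 128.0, 128

print(quantize(-1.0, scale, zero_point))  # 0
print(quantize(0.0, scale, zero_point))   # 128
print(quantize(1.0, scale, zero_point))   # 255
```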
