In [1]:
import numpy as np  # noqa
import pandas as pd
import argparse
import tensorflow as tf
from tqdm.auto import tqdm

from tensorflow.keras import layers as L
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import Sequence
from tensorflow.keras.backend import clear_session
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.metrics import categorical_accuracy, top_k_categorical_accuracy

import efficientnet.tfkeras as efn

import cv2
import os

In [2]:
# Restrict this kernel to physical GPU 1. CUDA_VISIBLE_DEVICES is only honoured
# if it is set before TensorFlow initializes CUDA. It is set here, after
# `import tensorflow`, and the import cell does not appear to touch the GPU.
# Setting it before the import would be the safer placement.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

In [3]:
# Keras checkpoint files to evaluate, passed to `load_model` as paths relative
# to the working directory. NOTE(review): the names look like
# "<EfficientNet variant>-<epoch>.h5"; this is inferred from the filenames only.
models = [
    "B5-25.h5",
    "B5-45.h5",
    "B5-60.h5",
    "B6-41.h5",
    "B7-30.h5",
    "B7-9.h5",
    "B6-21.h5",
    "B7-29.h5",
]

In [4]:
def one_hot(image, label, depth=1292):
    """One-hot encode ``label``, passing ``image`` through unchanged.

    Written to be used with ``tf.data.Dataset.map`` over ``(image, label)``
    pairs. In this notebook it is mapped after ``.batch``, so ``label`` is a
    batch of integer class ids there.

    Args:
        image: Image tensor, returned as-is.
        label: Integer class-id tensor.
        depth: Number of classes, i.e. the size of the new one-hot axis.
            The default of 1292 is the value this notebook always used.
            NOTE(review): presumably the width of the checkpoints' output
            layer; confirm.

    Returns:
        ``(image, tf.one_hot(label, depth))``.
    """
    return image, tf.one_hot(label, depth)

def read_tfrecords(example, input_size):
    """Parse one serialized record into an ``(image, label)`` pair.

    Args:
        example: A serialized ``tf.train.Example`` string tensor, as yielded
            by ``tf.data.TFRecordDataset``.
        input_size: ``(height, width)`` tuple of Python ints. The decoded
            image is reshaped to ``input_size + (1,)``.

    Returns:
        ``(img, unique_tuple)``.
        ``img`` is a float32 tensor of shape ``input_size + (3,)`` holding the
        single decoded channel repeated three times. Pixel values are cast but
        not rescaled. NOTE(review): presumably this matches the training-time
        preprocessing.
        ``unique_tuple`` is the record's ``unique_tuple`` feature cast to int32.
    """
    # Every feature the records are expected to carry. Only `img` and
    # `unique_tuple` are used below. Because none has a default value,
    # parsing fails if a record is missing any of them.
    features = {
        'img': tf.io.FixedLenFeature([], tf.string),
        'image_id': tf.io.FixedLenFeature([], tf.int64),
        'grapheme_root': tf.io.FixedLenFeature([], tf.int64),
        'vowel_diacritic': tf.io.FixedLenFeature([], tf.int64),
        'consonant_diacritic': tf.io.FixedLenFeature([], tf.int64),
        'unique_tuple': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, features)

    # channels=1 forces a single-channel decode, so the reshape below stays
    # valid even if a record was encoded as RGB or with alpha.
    # expand_animations=False guarantees a rank-3 (H, W, C) result.
    img = tf.image.decode_image(parsed['img'], channels=1, expand_animations=False)
    img = tf.reshape(img, input_size + (1,))
    img = tf.cast(img, tf.float32)
    # Grayscale -> RGB. NOTE(review): presumably the checkpoints were built
    # with a 3-channel input.
    img = tf.repeat(img, 3, axis=-1)

    unique_tuple = tf.cast(parsed['unique_tuple'], tf.int32)
    return img, unique_tuple

In [5]:
# Options are parsed with argparse so the same code can run as a script.
# parse_known_args (not parse_args) ignores the extra argv entries the
# Jupyter kernel is launched with.
# NOTE(review): this notebook only reads input_size, batch_size and lr. The
# other options appear to be carried over from a training script.
parser = argparse.ArgumentParser()
parser.add_argument('--model_id', type=int, default=0)
parser.add_argument('--seed', type=int, default=123)
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--input_size', type=str, default='160,256')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--epochs', type=int, default=60)
parser.add_argument('--backbone', type=str, default='efficientnet-b5')
parser.add_argument('--weights', type=str, default='imagenet')
args, _ = parser.parse_known_args()

# "160,256" -> (160, 256). read_tfrecords reshapes decoded images to this size.
args.input_size = tuple(int(x) for x in args.input_size.split(','))
AUTO = tf.data.experimental.AUTOTUNE

# Sort for a stable file order. Fail loudly if nothing matched, rather than
# "evaluating" every model on an empty dataset.
val_fns = sorted(tf.io.gfile.glob('./records/val*.tfrec'))
assert val_fns, "no validation TFRecords found matching ./records/val*.tfrec"

val_ds = (
    tf.data.TFRecordDataset(val_fns, num_parallel_reads=AUTO)
    .map(lambda e: read_tfrecords(e, args.input_size), num_parallel_calls=AUTO)
    .batch(args.batch_size)
    # One-hot after batching, so tf.one_hot runs once per batch.
    .map(one_hot, num_parallel_calls=AUTO)
    # Overlap input decoding with model evaluation.
    .prefetch(AUTO)
)

In [6]:
# Evaluate each checkpoint on the same validation set and collect its
# loss/metrics into one table for side-by-side comparison.
eval_records = []
for modelName in tqdm(models, desc="evaluating"):
    print(modelName)
    # compile=False: the model is recompiled right below, so any saved
    # optimizer/loss configuration is not needed.
    model = load_model(modelName, compile=False)
    # The optimizer is required by compile() but is never stepped during
    # evaluate(). `learning_rate` is the supported keyword; the `lr` alias is
    # deprecated and rejected by newer Keras optimizers.
    model.compile(
        optimizer=Adam(learning_rate=args.lr),
        loss=categorical_crossentropy,
        metrics=[categorical_accuracy, top_k_categorical_accuracy],
    )
    scores = model.evaluate(val_ds)
    # metrics_names is ['loss', <metric names>...], aligned with `scores`.
    eval_records.append({"model": modelName, **dict(zip(model.metrics_names, scores))})
    # Release the model before loading the next checkpoint.
    del model
    clear_session()
    print()

eval_results = pd.DataFrame(eval_records).set_index("model")
eval_results

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

B5-25.h5
   1256/Unknown - 99s 79ms/step - loss: 0.0977 - categorical_accuracy: 0.9750 - top_k_categorical_accuracy: 0.9980
B5-45.h5
   1256/Unknown - 98s 78ms/step - loss: 0.0679 - categorical_accuracy: 0.9825 - top_k_categorical_accuracy: 0.9987
B5-60.h5
   1256/Unknown - 96s 76ms/step - loss: 0.0700 - categorical_accuracy: 0.9818 - top_k_categorical_accuracy: 0.9988
B6-41.h5
   1256/Unknown - 124s 99ms/step - loss: 0.0862 - categorical_accuracy: 0.9792 - top_k_categorical_accuracy: 0.9985
B7-30.h5
   1256/Unknown - 164s 131ms/step - loss: 0.0683 - categorical_accuracy: 0.9828 - top_k_categorical_accuracy: 0.9987
B7-9.h5
   1256/Unknown - 166s 132ms/step - loss: 0.0938 - categorical_accuracy: 0.9755 - top_k_categorical_accuracy: 0.9978
B6-21.h5
   1256/Unknown - 124s 98ms/step - loss: 0.0953 - categorical_accuracy: 0.9763 - top_k_categorical_accuracy: 0.9978
B7-29.h5
   1256/Unknown - 97s 77ms/step - loss: 0.0727 - categorical_accuracy: 0.9811 - top_k_categorical_accuracy: 0.9987
B7-