In [1]:
import efficientnet.keras as efn
import tensorflow as tf

from keras.applications import VGG16
from keras import layers, Model, Sequential
from keras.optimizers import RMSprop
from keras.callbacks import LearningRateScheduler

import numpy as np

Using TensorFlow backend.


In [2]:
!nvidia-smi # show GPU model, driver/CUDA versions, memory usage and running processes

Tue Jun 23 18:55:04 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Quadro P6000        On   | 00000000:00:05.0 Off |                  Off |
| 26%   53C    P0    72W / 250W |      1MiB / 24449MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [3]:
from tensorflow.python.client import device_lib

In [4]:
# Enumerate every device TensorFlow can place ops on (CPU, GPU and their XLA variants)
device_lib.list_local_devices()

[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 7855734116297642167, name: "/device:XLA_CPU:0"
 device_type: "XLA_CPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 191168330205273856
 physical_device_desc: "device: XLA_CPU device", name: "/device:XLA_GPU:0"
 device_type: "XLA_GPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 18437516517916742948
 physical_device_desc: "device: XLA_GPU device", name: "/device:GPU:0"
 device_type: "GPU"
 memory_limit: 24199030375
 locality {
   bus_id: 1
   links {
   }
 }
 incarnation: 8585583398489412047
 physical_device_desc: "device: 0, name: Quadro P6000, pci bus id: 0000:00:05.0, compute capability: 6.1"]

In [5]:
# Channels-last image shape as Keras expects it: (height/rows, width/columns, channels).
INPUT_SHAPE = (240, 240, 3)

In [6]:
# Width of the softmax output layer (single-label, multi-class problem).
NUM_CLASSES = 7

In [7]:
# EfficientNet-B3 convolutional base with ImageNet weights and no classifier top;
# a custom classification head is attached to it later in the notebook.
# `classes` is only consulted when include_top=True, so passing NUM_CLASSES here
# had no effect and has been dropped to avoid suggesting otherwise.
efn_base = efn.EfficientNetB3(
    weights='imagenet',
    include_top=False,
    input_shape=INPUT_SHAPE
)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [8]:
# efn_base.summary()

In [9]:
# New classification head on top of the EfficientNet base: global max-pool the
# feature maps, apply dropout, then a NUM_CLASSES-way softmax.
model = Sequential([
    efn_base,
    layers.GlobalMaxPooling2D(),
    layers.Dropout(0.2),
    layers.Dense(NUM_CLASSES, activation='softmax'),
])

In [10]:
# model.summary()

In [11]:
# Mark every layer of the pretrained base as non-trainable; this takes effect
# when the model is compiled, so only the new head's weights get updated.
for base_layer in efn_base.layers:
    base_layer.trainable = False

In [12]:
from h5imagegenerator import HDF5ImageGenerator

In [13]:
from albumentations import (
    Compose, ShiftScaleRotate, HorizontalFlip
)

In [14]:
# Geometric augmentation (albumentations): small random shifts and rescales
# combined with random rotations limited to 30 degrees.
aug = Compose([
    ShiftScaleRotate(
        rotate_limit=30,
        scale_limit=0.06,
        shift_limit=0.0225,
    ),
])

In [15]:
# HDF5-backed batch generators. Both share the same label encoding and batch
# size; only the training stream is augmented.
common_gen_kwargs = dict(
    scaler=False,
    labels_encoding='smooth',
    num_classes=NUM_CLASSES,  # keep in sync with the model's output layer
    batch_size=32,
)

# Pass the augmentation pipeline: `aug` was previously defined but never used,
# so training silently ran on un-augmented images.
train_gen = HDF5ImageGenerator(
    src='./train.h5',
    augmenter=aug,
    **common_gen_kwargs
)

# Validation images are deliberately NOT augmented so metrics reflect real data.
val_gen = HDF5ImageGenerator(
    src='./val.h5',
    **common_gen_kwargs
)

In [16]:
def step_decay(epoch, *, initial_lr=1e-3, factor=0.5, drop_rate=5):
    """Step-decay learning-rate schedule for ``LearningRateScheduler``.

    The rate starts at ``initial_lr`` and is multiplied by ``factor`` once every
    ``drop_rate`` epochs. Because the step count uses ``1 + epoch``, the first
    drop happens at 0-based epoch ``drop_rate - 1`` (epoch 4 with the defaults).

    Args:
        epoch: 0-based epoch index supplied by the callback.
        initial_lr: learning rate before any decay.
        factor: multiplicative decay applied at each step.
        drop_rate: number of epochs between drops; must be positive.

    Returns:
        The learning rate for ``epoch`` as a Python float.

    Raises:
        ValueError: if ``drop_rate`` is not positive.
    """
    # The tuning knobs are keyword-only on purpose. Keras' LearningRateScheduler
    # first calls ``schedule(epoch, lr)`` and falls back to ``schedule(epoch)``
    # only on TypeError. A positional second parameter would therefore silently
    # receive the *current* lr, which would compound the decay every epoch.
    if drop_rate <= 0:
        raise ValueError(f"drop_rate must be positive, got {drop_rate!r}")
    steps = np.floor((1 + epoch) / drop_rate)
    return float(initial_lr * (factor ** steps))

In [17]:
# LearningRateScheduler(step_decay) overwrites the optimizer's learning rate at
# the start of every epoch, so the value given to RMSprop is never used for an
# update. The old hard-coded 2e-5 was dead, because training actually started at
# step_decay(0). Seed the optimizer from the same schedule so the configured and
# effective rates agree. To train at a different rate, change the schedule.
rms = RMSprop(learning_rate=step_decay(0))

model.compile(
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    optimizer=rms
)

h = model.fit(
    train_gen,
    validation_data=val_gen,
    callbacks=[LearningRateScheduler(step_decay)],
    verbose=1,
    epochs=30
)


Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
 25/206 [==>...........................] - ETA: 1:29 - loss: 1.3567 - accuracy: 0.5714

KeyboardInterrupt: 