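# Preparation

Fine-tune an ImageNet-pretrained MobileNetV2 on the 10-class image dataset in `data/`, then export the best checkpoint to TensorFlow.js for the frontend. This section imports the libraries and enables on-demand GPU memory growth.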

In [1]:
import tensorflow as tf
import tensorflowjs as tfjs

In [2]:
# Ask TensorFlow to grow GPU memory on demand for every visible GPU instead of
# reserving all of it when the first op runs.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in list(gpus):
    tf.config.experimental.set_memory_growth(device=gpu, enable=True)

## Data

In [3]:
# Dataset / input configuration shared by the loading, preprocessing and model cells.
NUM_CLASSES, BATCH_SIZE = 10, 32   # class sub-directories under data/; images per batch
IMG_SIZE = (128,) * 2              # (height, width) every image is resized to

In [4]:
# Load the images from data/ (one sub-directory per class), resized to IMG_SIZE
# and grouped into batches of BATCH_SIZE. Labels are inferred from the folder
# names as integer class indices; one-hot encoding happens later in `preprocess`.
# The fixed seed makes the file-order shuffle reproducible across kernel
# restarts. The take/skip train/test/val split built from `data` is therefore
# (up to the local shuffle-buffer caveat noted at the split) the same every
# run, rather than being redrawn each time.
DATA_SEED = 42
data = tf.keras.utils.image_dataset_from_directory(
    'data',
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=DATA_SEED,
)

Found 28138 files belonging to 10 classes.


In [5]:
# Split the batched dataset roughly 80/10/10 by *batch* count. The validation
# split takes whatever remains after train and test, so int() truncation never
# silently drops trailing batches.
# NOTE(review): image_dataset_from_directory(shuffle=True) also applies a local
# shuffle buffer that (in current TF versions — verify for the installed one)
# is redrawn on every iteration. Elements near the take/skip boundaries can
# therefore move between splits across epochs. A strictly leak-free split needs
# `validation_split`/`subset` or a split of the file list instead.
num_batches = len(data)
train_size = int(0.8 * num_batches)
test_size = int(0.1 * num_batches)
val_size = num_batches - train_size - test_size
assert min(train_size, test_size, val_size) > 0, "dataset too small for an 80/10/10 split"

train_data = data.take(train_size)
test_data = data.skip(train_size).take(test_size)
val_data = data.skip(train_size + test_size).take(val_size)

In [6]:
# Random augmentation pipeline. It is built once at module level so its layers
# are not re-created inside the function that tf.data traces for `Dataset.map`.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.2),
    tf.keras.layers.RandomZoom(0.2),
])

def augment(images, labels):
    """Randomly flip, rotate and zoom a batch of images; labels pass through unchanged.

    Args:
        images: a batch of images from the training split. This function is
            mapped before `preprocess`, so pixels are not yet rescaled.
        labels: the batch's class labels, returned as-is.

    Returns:
        A tuple ``(augmented_images, labels)``.
    """
    # training=True states explicitly that the random layers must be active here.
    return data_augmentation(images, training=True), labels

def preprocess(images, labels):
    """Convert a batch to model inputs: float pixels divided by 255 and one-hot labels.

    Args:
        images: a batch of images, presumably with 0–255 pixel values (as
            produced by image_dataset_from_directory) — confirm if the
            source changes.
        labels: integer class indices in ``[0, NUM_CLASSES)``.

    Returns:
        A tuple ``(scaled_images, one_hot_labels)`` where the labels have a
        trailing dimension of size NUM_CLASSES.

    NOTE(review): tf.keras.applications.mobilenet_v2.preprocess_input scales
    inputs to [-1, 1], the range the ImageNet weights were trained on. This
    keeps the original /255 scaling, presumably because the exported TF.js
    model's frontend feeds /255 inputs — confirm before changing either side.
    """
    one_hot_labels = tf.one_hot(labels, NUM_CLASSES)
    float_images = tf.cast(images, tf.float32)
    return float_images / 255.0, one_hot_labels

# Training batches are augmented and then rescaled; the evaluation splits are
# only rescaled. Parallel mapping plus prefetching lets input preparation
# overlap with model steps instead of stalling each one. Element order is
# unchanged, because map() stays deterministic by default.
# NOTE(review): this cell reassigns the split datasets in place. Running it a
# second time without re-running the split cell would re-apply the maps to
# already-processed data.
AUTOTUNE = tf.data.AUTOTUNE

train_data = (
    train_data
    .map(augment, num_parallel_calls=AUTOTUNE)
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)

test_data = test_data.map(preprocess, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

val_data = val_data.map(preprocess, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

# Models

## MobileNetV2

In [7]:
# ImageNet-pretrained MobileNetV2 feature extractor (no classification top)
# for IMG_SIZE RGB inputs.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SIZE + (3,),
    include_top=False,
    weights='imagenet'
)

# Fine-tune only the top of the backbone. All layers start out trainable, so
# the earlier layers must be frozen explicitly. Setting the last N layers to
# trainable=True alone is a no-op and would fine-tune the whole network.
# The slice length is computed explicitly so that FINE_TUNE_LAYERS = 0 freezes
# everything (layers[:-0] would freeze nothing).
# NOTE(review): BatchNormalization layers among the unfrozen ones will update
# their moving statistics during training. The Keras fine-tuning guide
# recommends keeping BN in inference mode when fine-tuning — consider freezing
# those as well.
FINE_TUNE_LAYERS = 30
base_model.trainable = True
num_frozen = max(len(base_model.layers) - FINE_TUNE_LAYERS, 0)
for layer in base_model.layers[:num_frozen]:
    layer.trainable = False

def _classifier_head(num_classes):
    """Build the head that sits on top of the backbone's feature maps.

    The head is global average pooling, then three (BatchNorm -> Dropout ->
    Dense) stages: 1024 relu, 512 relu, and a softmax over ``num_classes``.
    Layers are created in the same order as a literal list would create them,
    so auto-generated layer names are unchanged.
    """
    head_layers = [tf.keras.layers.GlobalAveragePooling2D()]
    stages = (
        (0.3, 1024, 'relu'),
        (0.4, 512, 'relu'),
        (0.3, num_classes, 'softmax'),
    )
    for dropout_rate, units, activation in stages:
        head_layers += [
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dropout(dropout_rate),
            tf.keras.layers.Dense(units, activation=activation),
        ]
    return head_layers


model = tf.keras.Sequential([base_model, *_classifier_head(NUM_CLASSES)])

# Adam with a small learning rate, as is usual when updating pretrained weights.
# Categorical cross-entropy matches the one-hot labels produced by `preprocess`.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

# Training callbacks. EarlyStopping and ModelCheckpoint both track the same
# metric (previously the checkpoint tracked val_accuracy while early stopping
# restored weights by val_loss). When early stopping restores its best weights,
# the in-memory model evaluated on the test set is then the same epoch as the
# checkpoint that is reloaded and exported to TF.js.
BEST_MODEL_METRIC = 'val_loss'

callbacks = [
    # Stop after 3 epochs without improvement and roll back to the best epoch.
    tf.keras.callbacks.EarlyStopping(
        monitor=BEST_MODEL_METRIC,
        patience=3,
        restore_best_weights=True
    ),
    # Keep only the best model on disk; this file is what gets exported later.
    tf.keras.callbacks.ModelCheckpoint(
        filepath='models/MobileNetV2.keras',
        monitor=BEST_MODEL_METRIC,
        save_best_only=True
    ),
    # Multiply the learning rate by 0.2 after 2 stagnant epochs.
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=2
    ),
    tf.keras.callbacks.TensorBoard(log_dir="logs")
]

### Training

In [8]:
# Train for up to 10 epochs, validating on val_data each epoch. The callbacks
# may stop training early and adjust the learning rate.
fit_options = {
    'epochs': 10,
    'validation_data': val_data,
    'callbacks': callbacks,
}
history = model.fit(train_data, **fit_options)

# Final held-out check on the test split. evaluate() returns [loss, accuracy]
# because the model was compiled with metrics=['accuracy'].
evaluation = model.evaluate(test_data)
test_loss, test_accuracy = evaluation
print(f"\nTest accuracy: {test_accuracy:.4f}")

Epoch 1/10
Epoch 2/10
Epoch 3/10

KeyboardInterrupt: 

In [5]:
# Reload the best model written by ModelCheckpoint (save_best_only=True) and
# export it with the TF.js converter into the frontend's public models folder.
checkpoint_path = 'models/MobileNetV2.keras'
tfjs_output_dir = 'frontend/public/models/MobileNetV2_tfjs'
model = tf.keras.models.load_model(checkpoint_path)
tfjs.converters.save_keras_model(model, tfjs_output_dir)