In [None]:
# %pip install -q tensorflow==2.4.1  # version this notebook was last run with (see tf.__version__ output)
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm 
import random, time, os
from sklearn.model_selection import train_test_split
import pandas as pd

# Mount Google Drive only when /gdrive is not already present, so that
# re-running this cell does not try to mount a second time.
if not os.path.isdir("/gdrive"):
    # Imported here because google.colab is only needed for the mount step.
    from google.colab import drive
    drive.mount("/gdrive")

# If ./data already exists, list its contents; otherwise unzip the dataset
# archive from Drive into ./data (unzip's stdout is discarded).
!test -d data && ls -l data/ || unzip /gdrive/MyDrive/dataset/shopee-code-league-2020-product-detection.zip -d data 1>/dev/null

# Last expression of the cell: shows the TensorFlow version in the output.
tf.__version__
# Reference tutorial:
# https://www.kaggle.com/fadheladlansyah/product-detection-effnetb5-aug-tta

Mounted at /gdrive


'2.4.1'

## 1. Input

In [None]:
def get_data(
    data_dir="data/resized/train/",
    validation_split=.1,
    seed=1,
    image_size=(299, 299),
    batch_size=32,
):
    """Build the training and validation datasets from one image directory.

    Both datasets are created with the same ``validation_split`` and ``seed``
    so that they partition ``data_dir`` into two disjoint subsets. Class labels
    are inferred from the sub-directory names and encoded as integers.

    Args:
        data_dir: Directory with one sub-directory per class.
        validation_split: Fraction of the files reserved for validation.
        seed: Seed shared by both calls; it must be identical for the two
            subsets to be complementary.
        image_size: ``(height, width)`` every image is resized to.
        batch_size: Number of images per batch (32 is the Keras default, which
            is what the previous hard-coded call relied on).

    Returns:
        A ``(train, validation)`` tuple of datasets.
    """
    def load_subset(subset):
        # One place for the shared arguments, so the two subsets cannot drift
        # apart (the validation set used to be requested with
        # subset="training", which made it a copy of the training data).
        return tf.keras.preprocessing.image_dataset_from_directory(
            data_dir,
            validation_split=validation_split,
            subset=subset,
            seed=seed,
            labels="inferred",
            label_mode="int",
            image_size=image_size,
            batch_size=batch_size,
        )

    train = load_subset("training")
    validation = load_subset("validation")
    return train, validation

## 2. Layers and Models

### 2.1 Preprocessing input

In [None]:
class Preprocess(tf.keras.layers.Layer):
    """Keras layer that runs EfficientNet's ``preprocess_input`` on its input.

    NOTE(review): the Keras documentation describes
    ``efficientnet.preprocess_input`` as a pass-through (the EfficientNet
    models rescale internally) -- confirm for the installed version.
    """

    def __init__(self):
        super().__init__()

    def call(self, X):
        # Delegate entirely to the Keras-provided EfficientNet preprocessing.
        return tf.keras.applications.efficientnet.preprocess_input(X)

### 2.2 Augmentation layer

In [None]:
class Augmentation(tf.keras.layers.Layer):
    """Random image augmentation that is applied only while training.

    When ``training`` is truthy the input goes through a random horizontal
    flip, a random brightness shift and a random contrast change. Otherwise
    (evaluation, prediction, or ``training=None``) the input is returned
    unchanged, so validation metrics and predictions are not perturbed by
    random distortions.
    """

    def __init__(self):
        super(Augmentation, self).__init__()
        # An optional CutOut step is currently disabled (CutOut is not defined
        # in the visible cells):
        # self.cutout = CutOut(.1)

    def call(self, X, training=None):
        """Return ``X`` augmented when ``training`` is truthy, else ``X`` as is.

        Args:
            X: Image tensor to augment.
            training: Keras training flag; Keras forwards it to layers whose
                ``call`` accepts a ``training`` argument.
        """
        if training:
            X = tf.image.random_flip_left_right(X)
            # NOTE(review): max_delta is added directly to the pixel values.
            # If the input is still on a 0-255 scale here (Preprocess precedes
            # this layer in get_model), a shift of at most 0.5 is barely
            # visible -- confirm the intended scale.
            X = tf.image.random_brightness(X, max_delta=0.5)
            X = tf.image.random_contrast(X, lower=0.75, upper=1.2)
            # X = self.cutout(X)
        return X

### 2.3 Model

In [None]:
def get_model(num_classes=42, learning_rate=1e-6):
    """Build and compile the EfficientNetB5-based classifier.

    The network is a ``Sequential`` stack of: ``Preprocess``, ``Augmentation``,
    an ImageNet-initialised EfficientNetB5 without its classification head,
    global average pooling, and a softmax ``Dense`` layer. All but the last 10
    layers of the EfficientNet base are frozen.

    Args:
        num_classes: Number of units in the softmax output layer (42 matches
            the number of classes reported for the dataset in this notebook).
        learning_rate: Learning rate of the SGD optimizer.

    Returns:
        The compiled Keras model (sparse categorical cross-entropy loss,
        accuracy metric).
    """
    # NOTE(review): the device is pinned to the first GPU; confirm how this
    # behaves on a runtime without a GPU.
    with tf.device("/device:GPU:0"):
        # Load EfficientNetB5 as a headless feature extractor.
        base = tf.keras.applications.efficientnet.EfficientNetB5(
            include_top=False,
            weights="imagenet",
            pooling=None,
        )
        # Freeze everything except the last 10 layers of the base.
        for layer in base.layers[:-10]:
            layer.trainable = False

        net = tf.keras.models.Sequential([
            Preprocess(),
            Augmentation(),
            base,
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(num_classes, activation="softmax"),
        ])

        net.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
            # Labels are integer class ids, hence the sparse loss.
            loss="sparse_categorical_crossentropy",
            # Keras documents `metrics` as a list of metrics.
            metrics=["accuracy"],
        )

        return net

## 3. Training loop

### 3.1 Callbacks

In [None]:
# Reduce the learning rate by 10x when val_loss has not improved for 2 epochs.
# min_lr has to sit below the optimizer's starting rate: get_model() compiles
# SGD with learning_rate=1e-6, and ReduceLROnPlateau only lowers the rate while
# it is above min_lr, so the previous floor of 1e-5 meant it never acted.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.1, patience=2, verbose=1,
    mode='auto', min_delta=0.0001, cooldown=0, min_lr=1e-8
)

# Save a checkpoint (weights only) whenever val_accuracy reaches a new maximum.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="/gdrive/MyDrive/dataset/checkpoints/efficient_net",
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# Early stopping: stop once val_loss has not improved by at least 0.002 for
# 3 consecutive epochs.
# NOTE(review): this callback is defined here but is not in the callbacks list
# passed to net.fit -- confirm whether that is intentional.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0.002, patience=3, verbose=1)

### 3.2 Fit

In [None]:
# Training configuration.
# NOTE(review): no visible cell reads BATCH_SIZE -- get_data() is called
# without it, so the datasets are not batched at this size. Confirm intent.
BATCH_SIZE = 256
# Number of epochs passed to net.fit.
EPOCH = 5
# tf.data's AUTOTUNE sentinel value.
AUTO = tf.data.experimental.AUTOTUNE

In [None]:
# Input: load the (train, validation) dataset pair returned by get_data().
train, validation = get_data()

Found 105392 files belonging to 42 classes.
Using 94853 files for training.
Found 105392 files belonging to 42 classes.
Using 94853 files for training.


In [None]:
# Build and compile a fresh model.
net = get_model()

# Best-effort resume: try to restore the weights written by the checkpoint
# callback; on any failure report it and continue with the freshly built model.
try:
    net.load_weights("/gdrive/MyDrive/dataset/checkpoints/efficient_net")
    print("Loaded weight from last check points")

except Exception as e:
    # NOTE(review): this also catches errors other than a missing checkpoint
    # (the exception text is printed so they remain visible).
    print("Check point not found", e)

# Train.
# AUTOTUNE belongs to the tf.data pipeline (prefetch buffer size), not to
# fit()'s `workers` argument, which Keras documents as applying only to
# generator / keras.utils.Sequence inputs.
# NOTE(review): early_stop is defined in the callbacks cell but is not used
# here -- add it to the list if early stopping is wanted.
history = net.fit(
    train.prefetch(AUTO),
    validation_data=validation.prefetch(AUTO),
    epochs=EPOCH,
    callbacks=[reduce_lr, model_checkpoint_callback],
)

Loaded weight from last check points
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
