In [None]:
import tensorflow as tf
import keras
import os
import core_values as cova

In [None]:
# Ask TensorFlow to grow GPU memory on demand instead of reserving it all up
# front. This is best-effort: if configuring a device fails (e.g. the GPUs were
# already initialised), the error is reported and the notebook carries on.
gpus = tf.config.list_physical_devices('GPU')
print(gpus)

try:
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
except Exception as err:
    print(err)

In [None]:
# Sentinel telling tf.data to tune the prefetch buffer size dynamically at runtime.
AUTOTUNE = tf.data.AUTOTUNE

In [None]:
# Project root. It can be overridden with the PROJECT_PATH environment variable
# so the notebook runs on other machines; otherwise it falls back to the
# original local path. An empty PROJECT_PATH is treated as unset.
path = os.environ.get("PROJECT_PATH") or "/mnt/d/Tojo Sir - Project/"
# Directory holding the pre-processed train/valid datasets loaded below.
processed_path = os.path.join(path, "processed_data")

In [None]:
# Load the pre-processed train/validation splits from `processed_path`
# (tf.data.Dataset.load reads data written by tf.data.Dataset.save).
# Every later cell depends on these datasets, so a failure is reported with
# context and then re-raised. Swallowing it would only resurface later as a
# confusing NameError on `train_ds`.
try:
    train_ds = tf.data.Dataset.load(os.path.join(processed_path, "train"))
    valid_ds = tf.data.Dataset.load(os.path.join(processed_path, "valid"))
except Exception as e:
    print(f"Could not load datasets from {processed_path!r}: {e}")
    raise

In [None]:
# Overlap data loading with training: prefetch upcoming elements of both
# splits in the background, with the buffer size tuned by tf.data.
train, valid = (
    split.prefetch(buffer_size=AUTOTUNE) for split in (train_ds, valid_ds)
)

In [None]:
# Stop once val_loss has not improved for 5 epochs, rolling back to the best
# weights observed.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, verbose=1, restore_best_weights=True,
)

# Halve the learning rate after 3 epochs without val_loss improvement, with a
# floor of 1e-8. Its patience is shorter than EarlyStopping's, so at least one
# reduction is attempted before training is stopped.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=3, verbose=1, min_lr=1e-8,
)

# Abort training as soon as the loss becomes NaN.
terminate_nan = keras.callbacks.TerminateOnNaN()

In [None]:
# ImageNet-pretrained ResNet152V2 used as a frozen feature extractor.
# include_top=False drops the 1000-class ImageNet head. The `classes` argument
# only applies to that head, so it is not passed (the old `classes=4` was
# silently ignored). The 4-class output lives in the custom head of `model`.
base_model = keras.applications.ResNet152V2(
    include_top=False,
    weights="imagenet",
    input_shape=(cova.IMAGE_SIZE[0], cova.IMAGE_SIZE[1], 3),
)

# Freeze every backbone weight; only the new classification head is trained.
base_model.trainable = False

In [None]:
# Classifier: input normalisation -> frozen ResNet152V2 backbone -> small
# trainable head producing a 4-way softmax.
model = keras.models.Sequential([
    keras.layers.Input((cova.IMAGE_SIZE[0], cova.IMAGE_SIZE[1], 3)),

    # ResNetV2 ImageNet weights expect pixels in [-1, 1]. scale=2, offset=-1
    # maps [0, 1] -> [-1, 1], so this layer must run BEFORE the backbone.
    # Previously it was applied to the backbone's output features, where it is
    # just a linear map absorbed by the Dense head, and the backbone saw
    # un-normalised input.
    # NOTE(review): assumes the datasets yield images already scaled to [0, 1].
    # If they are raw [0, 255] pixels, use scale=1/127.5, offset=-1 instead.
    # Confirm against the preprocessing step.
    keras.layers.Rescaling(scale=2, offset=-1),

    base_model,

    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(128, activation="leaky_relu"),
    # Softmax probabilities over 4 classes, as expected by the
    # SparseCategoricalCrossentropy loss used at compile time.
    keras.layers.Dense(4, activation="softmax"),
])

In [None]:
# AdamW with its default settings and sparse (integer-label) cross-entropy on
# the softmax outputs. steps_per_execution=5 runs 5 batches per compiled call
# to reduce per-step Python overhead.
model.compile(
    optimizer='AdamW',
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
    steps_per_execution=5,
)

In [None]:
# Train with validation after each epoch. 256 is only an upper bound: the
# callbacks lower the LR on plateaus, abort on NaN loss and stop early.
model.fit(
    train,
    validation_data=valid,
    epochs=256,
    callbacks=[reduce_lr, terminate_nan, early_stopping],
)

In [None]:
# Write the trained model as a native `.keras` archive, relative to the current
# working directory, replacing any existing file with the same name.
model.save(
    "resnet152v2-trained-model.keras",
    overwrite=True,
)