# Importing the Necessary Libraries

In [None]:
import os
import shutil
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from datasets import load_from_disk
from IPython.display import Image
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.layers import (
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    GlobalAveragePooling2D,
    Input,
    InputLayer,
    MaxPool2D,
)
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy, Precision, Recall
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import L2
from tensorflow.keras.utils import plot_model
from wandb.keras import WandbCallback

import wandb

In [None]:
# ! pip install wandb
# ! pip install pydot
# ! pip install graphviz

PROJECT_SUFFIX = "5_datasets_v2"
ENTITY = "makersplace"
PROJECT = f"ai-or-not-{PROJECT_SUFFIX}"
SEED = 77
RUNTIME_DATE_SUFFIX = "%m%d_%H%M"

# Run identifiers: the job-type suffix groups runs of this experiment and the
# current local time (month/day_hour/minute) makes run and model names unique.
JOB_TYPE_SUFFIX = f"{PROJECT_SUFFIX}_cat"
RUN_NAME_SUFFIX = datetime.now().strftime(RUNTIME_DATE_SUFFIX)


# Datasets Paths
training_dataset_path = "../cache/data/training_dataset"
validation_dataset_path = "../cache/data/validation_dataset"

# Model Paths
cnn_model_path = Path(f"../cache/models/{JOB_TYPE_SUFFIX}/cnn_{RUN_NAME_SUFFIX}")
effv2_model_dir_path = Path(f"../cache/models/{JOB_TYPE_SUFFIX}/en2s_{RUN_NAME_SUFFIX}")


# When True, the cached training and validation dataset folders are deleted and recreated.
CLEAN_RUN = True


# Seed the NumPy and TensorFlow global RNGs for reproducibility.
np.random.seed(SEED)
tf.random.set_seed(SEED)

# ENVIRONMENT VARIABLES
# Credentials must never live in source control. Provide the W&B API key through
# the WANDB_API_KEY environment variable (or run `wandb login` once); wandb.login()
# picks it up from there.
# NOTE(review): the key that was previously hard-coded here is exposed in the
# notebook history and should be rotated.
if not os.environ.get("WANDB_API_KEY"):
    print("WANDB_API_KEY is not set; wandb.login() will prompt for credentials.")

# Dataset Creation

#### Importing the datasets and splitting them into training and validation sets

In [None]:
# Raw dataset locations: training splits.
dire_train_imagenet_directory = "../cache/data/DIRE/train/imagenet"
dire_train_celebahq_directory = "../cache/data/DIRE/train/celebahq"
dire_train_lsun_bedroom_directory = "../cache/data/DIRE/train/lsun_bedroom"
cifake_train_directory = "../cache/data/cifake/train"

# Raw dataset locations: validation splits (note the trailing slashes).
dire_validation_imagenet_directory = "../cache/data/DIRE/val/imagenet/"
dire_validation_lsun_bedroom_directory = "../cache/data/DIRE/val/lsun_bedroom/"

# Sub-folder (class) names used by each source dataset.
DIRE_IMAGENET_CLASS_NAMES = ["real", "adm"]
DIRE_CELEBAHQ_CLASS_NAMES = ["real", "sdv2"]
DIRE_LSUN_BEDROOM_CLASS_NAMES = ["real", "adm", "iddpm", "pndm", "stylegan"]
CIFAKE_CLASS_NAMES = ["REAL", "FAKE"]

# Display names indexed by the integer value of a label (0 -> REAL, 1 -> FAKE).
CLASS_NAMES = ["REAL", "FAKE"]

NUM_SHARDS, SHARD_INDEX = 5, 1
IMAGE_RESOLUTION = 128
EPOCHS = 10

# Hyper-parameters for the baseline CNN; this dict is also passed to W&B as the run config.
CONFIGURATION = dict(
    BATCH_SIZE=32,
    IM_SIZE=IMAGE_RESOLUTION,
    DROPOUT_RATE=0.1,
    N_EPOCHS=EPOCHS,
    REGULARIZATION_RATE=0.001,
    N_FILTERS=6,
    KERNEL_SIZE=3,
    N_STRIDES=1,
    POOL_SIZE=2,
    N_DENSE_1=1024,
    N_DENSE_2=128,
    NUM_CLASSES=2,
    LEARNING_RATE=0.0001,
)

In [None]:
def resize_image(image):
    """Letterbox *image* to a CONFIGURATION["IM_SIZE"] square and divide it by 255.

    tf.image.resize_with_pad keeps the aspect ratio and pads the remainder.
    The division yields values in [0, 1] only if the input pixels are in
    [0, 255] (assumed here, not checked).
    """
    side = CONFIGURATION["IM_SIZE"]
    padded = tf.image.resize_with_pad(image=image, target_height=side, target_width=side)
    return padded / 255.0


def decode_img(img):
    """Decode encoded image bytes into a 3-channel image and letterbox it via resize_image.

    expand_animations=False makes tf.io.decode_image return a 3-D tensor for
    every format. For animated GIFs it keeps only the first frame. With the
    default (True), a GIF decodes to a 4-D frame stack and the element rank is
    unknown while the tf.data graph is built, which breaks resize_with_pad and
    batching for those files.
    """
    img = tf.io.decode_image(img, channels=3, expand_animations=False)
    return resize_image(img)


def process_path(file_path):
    """Read the raw bytes at *file_path* and return the decoded, resized image."""
    return decode_img(tf.io.read_file(file_path))


def print_dataset_summary(dataset, directory):
    """Print a short summary: banner, source directory, class names and file count.

    *dataset* must expose ``class_names`` and ``file_paths`` attributes.
    """
    summary_lines = (
        "************ Dataset summary ************",
        f"Directory: {directory}",
        f"Class names: {dataset.class_names}",
        f"File count: {len(dataset.file_paths)}",
    )
    for line in summary_lines:
        print(line)


def visualize_dataset(samples):
    """Show (image, label) pairs on a 4x4 grid, each titled with CLASS_NAMES[int(label)].

    The grid has 16 slots, so *samples* must yield at most 16 pairs.
    """
    plt.figure(figsize=(12, 12))
    for position, (image, label) in enumerate(samples, start=1):
        plt.subplot(4, 4, position)
        plt.imshow(image)
        plt.title(CLASS_NAMES[int(label)])
        plt.axis("off")
    plt.show()


def get_custom_dataset(directory, label, pattern):
    """Build an unbatched dataset of ``(image, label)`` pairs for one source directory.

    Args:
        directory: If the path contains "aiornot", a Hugging Face ``datasets``
            directory readable by ``load_from_disk``. Otherwise a directory of
            image files; it may itself contain glob characters (e.g. ``.../AI*``).
        label: Label attached to every image in the file-based branch. It is
            ignored for "aiornot" datasets, which use their own ``label`` column.
        pattern: Glob relative to *directory* selecting the image files
            (e.g. ``"*"`` or ``"*/*"``). Unused for "aiornot" datasets.

    Returns:
        A ``tf.data.Dataset`` whose images have been passed through
        ``resize_image`` and whose labels are float32.
    """
    if "aiornot" in directory:
        read_aiornot = load_from_disk(dataset_path=directory)
        dataset = read_aiornot.to_tf_dataset(
            columns="image",
            label_cols="label",
        )
        dataset = dataset.map(lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.float32)))
        dataset = dataset.map(lambda x, y: (resize_image(x), y), num_parallel_calls=tf.data.AUTOTUNE)
    else:
        # list_files(shuffle=True) reshuffles on every iteration. Callers split
        # the result with take()/skip() into train and validation, and each part
        # is iterated separately, so the two splits would overlap (leakage).
        # Instead, list files in a fixed order and shuffle them once with a fixed
        # seed and reshuffle_each_iteration=False, so every iteration sees the
        # same order.
        list_ds = tf.data.Dataset.list_files(str(Path(directory) / pattern), shuffle=False)
        list_ds = list_ds.shuffle(
            buffer_size=list_ds.cardinality(),
            seed=SEED,
            reshuffle_each_iteration=False,
        )
        dataset = list_ds.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.map(lambda x: (x, label))

    return dataset

# Dataset Preparation

In [None]:
DEFAULT_SHARDS = 1
REAL_IMAGE_SHARDS = 1


def _train_source(path, label, pattern, shards):
    """Return one training source as (directory, label, glob pattern, shard count, shard index 0)."""
    return (path, label, pattern, shards, 0)


# Label 0.0 marks real images and 1.0 marks AI-generated ones.
train_directories = [
    # AI Artbench Dataset
    _train_source("../cache/data/ai-artbench/train/real", 0.0, "*/*", REAL_IMAGE_SHARDS),
    _train_source("../cache/data/ai-artbench/train/AI*", 1.0, "*", DEFAULT_SHARDS),
    # CIFAKE Dataset
    _train_source("../cache/data/cifake/train/REAL", 0.0, "*", REAL_IMAGE_SHARDS),
    _train_source("../cache/data/cifake/train/FAKE", 1.0, "*", DEFAULT_SHARDS),
    # AIORNOT Dataset (get_custom_dataset ignores this label and uses the dataset's own "label" column)
    _train_source("../cache/data/aiornot/train", 0.0, "*", DEFAULT_SHARDS),
    # FakeImage Dataset
    # NOTE(review): this exact directory is also listed in test_directories, so most of
    # it is trained on and then reported as a test result. Confirm whether it belongs
    # in only one of the two lists.
    _train_source(
        "../cache/data/FakeImageDataset/ImageData/val/Midjourneyv5-5K/Midjourneyv5-5K_test",
        1.0,
        "*",
        1,
    ),
    # DIRE Imagenet Dataset
    _train_source("../cache/data/DIRE/train/imagenet/real", 0.0, "*/*", REAL_IMAGE_SHARDS),
    _train_source("../cache/data/DIRE/train/imagenet/adm", 1.0, "*/*", DEFAULT_SHARDS),
    # DIRE CelebaHQ Dataset (real images split into 2 shards, keeping shard 0)
    _train_source("../cache/data/DIRE/train/celebahq/real", 0.0, "*", 2),
    _train_source("../cache/data/DIRE/train/celebahq/sdv2", 1.0, "*/*", DEFAULT_SHARDS),
    # DIRE Lsun Bedroom Dataset
    _train_source("../cache/data/DIRE/train/lsun_bedroom/real", 0.0, "*", REAL_IMAGE_SHARDS),
    _train_source("../cache/data/DIRE/train/lsun_bedroom/stylegan", 1.0, "*", DEFAULT_SHARDS),
    # ('../cache/data/DIRE/train/lsun_bedroom/adm', 1.0, '*', 6, 0),
    # ('../cache/data/DIRE/train/lsun_bedroom/iddpm', 1.0, '*', 6, 0),
    # ('../cache/data/DIRE/train/lsun_bedroom/pndm', 1.0, '*', 6, 0),
]

In [None]:
# Test Directories
TEST_SHARDS = 1
test_directories = [
    # AI Artbench Dataset
    ("../cache/data/ai-artbench/test/AI*", 1.0, "*", TEST_SHARDS, 0),  # 675 Batches
    ("../cache/data/ai-artbench/test/real", 0.0, "*/*", TEST_SHARDS, 0),
    # # CIFAKE Dataset
    ("../cache/data/cifake/test/REAL", 0.0, "*", TEST_SHARDS, 0),
    ("../cache/data/cifake/test/FAKE", 1.0, "*", TEST_SHARDS, 0),
    # FakeImageDataset
    (
        "../cache/data/FakeImageDataset/ImageData/val/Midjourneyv5-5K/Midjourneyv5-5K_test",
        1.0,
        "*",
        TEST_SHARDS,
        0,
    ),
    (
        "../cache/data/FakeImageDataset/ImageData/val/SDv15-CC30K/SDv15-CC30K/",
        1.0,
        "*/*",
        TEST_SHARDS,
        0,
    ),  # 991 Batches
    (
        "../cache/data/FakeImageDataset/ImageData/val/SDv21-CC15K/SDv21-CC15K/SDv2-dpmsolver-25-10K",
        1.0,
        "*",
        TEST_SHARDS,
        0,
    ),  # 496 Batches
    (
        "../cache/data/FakeImageDataset/ImageData/val/cogview2-22K/cogview2-22K",
        1.0,
        "*",
        TEST_SHARDS,
        0,
    ),  # 698 Batches
    # DIRE Imagenet Dataset
    ("../cache/data/DIRE/test/imagenet/real", 0.0, "*/*", TEST_SHARDS, 0),
    ("../cache/data/DIRE/test/imagenet/adm", 1.0, "*/*", TEST_SHARDS, 0),
    ("../cache/data/DIRE/test/imagenet/sdv1", 1.0, "*/*", TEST_SHARDS, 0),
    # DIRE Imagenet Dataset
    ("../cache/data/DIRE/test/celebahq/real", 0.0, "*", TEST_SHARDS, 0),
    ("../cache/data/DIRE/test/celebahq/sdv2", 1.0, "*", TEST_SHARDS, 0),
    ("../cache/data/DIRE/test/celebahq/if", 1.0, "*", TEST_SHARDS, 0),
    ("../cache/data/DIRE/test/celebahq/dalle2", 1.0, "*", TEST_SHARDS, 0),
    # DIRE Lsun Bedroom Dataset
    ("../cache/data/DIRE/test/lsun_bedroom/dalle2", 1.0, "*", TEST_SHARDS, 0),
    ("../cache/data/DIRE/test/lsun_bedroom/midjourney", 1.0, "*", TEST_SHARDS, 0),
    ("../cache/data/DIRE/test/lsun_bedroom/sdv1_new", 1.0, "*", TEST_SHARDS, 0),
    ("../cache/data/DIRE/test/lsun_bedroom/sdv2", 1.0, "*", TEST_SHARDS, 0),
    ("../cache/data/DIRE/test/lsun_bedroom/vqdiffusion", 1.0, "*", TEST_SHARDS, 0),
]

In [None]:
# declare empty dataset for trainingn and validation
training_dataset = None
validation_dataset = None

for directory, label, pattern, shards, shard_index in train_directories:
    print(
        f"""
**********************************************************************************************************************************          
          Directory:  {directory} 
**********************************************************************************************************************************
    """
    )

    current_dataset = get_custom_dataset(directory, label, pattern)

    if shards > 1:
        # shard the dataset
        current_dataset = current_dataset.shard(num_shards=shards, index=shard_index)

    # split dataset into train and validation
    dataset_size = len(current_dataset)
    train_size = int(0.8 * dataset_size)
    validation_size = int(0.2 * dataset_size)

    current_train_dataset = current_dataset.skip(validation_size)
    current_validation_dataset = current_dataset.take(validation_size)

    if training_dataset is None:
        training_dataset = current_train_dataset
    else:
        training_dataset = training_dataset.concatenate(current_train_dataset)

    if validation_dataset is None:
        validation_dataset = current_validation_dataset
    else:
        validation_dataset = validation_dataset.concatenate(current_validation_dataset)

    # VISUALIZE DATASET
    visualize_dataset(current_dataset.take(16))

training_dataset = (
    training_dataset.shuffle(buffer_size=training_dataset.cardinality(), seed=SEED)
    .batch(CONFIGURATION["BATCH_SIZE"])
    .prefetch(tf.data.AUTOTUNE)
)

validation_dataset = validation_dataset.batch(CONFIGURATION["BATCH_SIZE"]).prefetch(tf.data.AUTOTUNE)

In [None]:
def dataset_display_name(directory):
    """Return a column label for a test directory that is unique across sources.

    The label is the dataset family (4th path component, e.g. "DIRE") followed
    by the last two path components, e.g.
    "../cache/data/DIRE/test/celebahq/real" -> "DIRE : celebahq/real".
    Using only the last component made distinct sources collide (imagenet/real
    vs. celebahq/real, the two sdv2 and dalle2 folders). It also produced an
    empty name for paths ending in "/". pathlib drops that trailing slash.
    """
    parts = Path(directory).parts
    return f"{parts[3]} : {'/'.join(parts[-2:])}"


def log_test_metrics(current_model, wandb_run):
    """Evaluate *current_model* on every entry of test_directories and log tables to W&B.

    Three one-row tables, keyed by the run name, are logged on *wandb_run*:
    "Loss" (loss per test source), "Accuracy" (accuracy per test source) and
    "Overall Results" (loss/accuracy/precision/recall on all test sources
    concatenated).

    Assumes the model was compiled with exactly three metrics (accuracy,
    precision, recall), so ``evaluate`` returns four values.

    Raises:
        ValueError: if test_directories is empty.
    """
    # One-row dataframes; each test source adds a column.
    loss_dataframe = pd.DataFrame()
    accuracy_dataframe = pd.DataFrame()
    composite_dataframe = pd.DataFrame()

    loss_dataframe["model"] = [wandb_run.name]
    accuracy_dataframe["model"] = [wandb_run.name]

    composite_dataset = None

    # Test sources are evaluated in full; their shard settings are not applied.
    for directory, label, pattern, _shards, _shard_index in test_directories:
        current_dataset = get_custom_dataset(directory, label, pattern)
        current_dataset = current_dataset.batch(CONFIGURATION["BATCH_SIZE"])

        # Per-source precision/recall are computed but only loss/accuracy are tabulated.
        test_loss, test_accuracy, _test_precision, _test_recall = current_model.evaluate(current_dataset)

        dataset_name = dataset_display_name(directory)
        loss_dataframe[dataset_name] = [test_loss]
        accuracy_dataframe[dataset_name] = [test_accuracy]

        composite_dataset = (
            current_dataset
            if composite_dataset is None
            else composite_dataset.concatenate(current_dataset)
        )

    if composite_dataset is None:
        raise ValueError("test_directories is empty; nothing to evaluate.")

    (
        composite_loss,
        composite_accuracy,
        composite_precision,
        composite_recall,
    ) = current_model.evaluate(composite_dataset)
    composite_dataframe["model"] = [wandb_run.name]
    composite_dataframe["loss"] = [composite_loss]
    composite_dataframe["accuracy"] = [composite_accuracy]
    composite_dataframe["precision"] = [composite_precision]
    composite_dataframe["recall"] = [composite_recall]

    wandb_run.log({"Loss": wandb.Table(dataframe=loss_dataframe)})
    wandb_run.log({"Accuracy": wandb.Table(dataframe=accuracy_dataframe)})
    wandb_run.log({"Overall Results": wandb.Table(dataframe=composite_dataframe)})


# def log_test_metrics(current_model, wandb_run):
#     # declare empty dataframe and add results to it for each test dataset
#     loss_dataframe = pd.DataFrame()
#     accuracy_dataframe = pd.DataFrame()

#     loss_dataframe['model'] = [wandb_run.name]
#     accuracy_dataframe['model'] = [wandb_run.name]

#     for directory, label, pattern, shards, shard_index in test_directories:

#         current_dataset = get_custom_dataset(directory, label, pattern)
#         # current_dataset.ignore_errors()
#         current_dataset = current_dataset.batch(CONFIGURATION["BATCH_SIZE"])

#         # evaluate model on current_dataset and capture metrics
#         test_loss, test_accuracy, test_precision, test_recall = current_model.evaluate(current_dataset)
#         # Extract the 3rd part of the directory path
#         dataset_name = directory.split('/')[3] + ' : ' + directory.split('/')[-1]


#         # add a column to the dataframe
#         loss_dataframe[dataset_name] = [test_loss]
#         accuracy_dataframe[dataset_name] = [test_accuracy]

#     # wandb_run.Table(dataframe = loss_dataframe)
#     # wandb_run.table(dataframe = accuracy_dataframe)

#     wandb_run.log({"Loss": wandb.Table(dataframe = loss_dataframe)})
#     wandb_run.log({"Accuracy": wandb.Table(dataframe = accuracy_dataframe)})

# Data Visualization

In [None]:
# Show the first 16 images of the first (shuffled) training batch with their class names.
plt.figure(figsize=(12, 12))
for batch_images, batch_labels in training_dataset.take(1):
    for slot in range(1, 17):
        plt.subplot(4, 4, slot)
        plt.imshow(batch_images[slot - 1])
        plt.title(CLASS_NAMES[int(batch_labels[slot - 1])])
        plt.axis("off")

In [None]:
# With CLEAN_RUN, delete previously saved copies so both datasets are rebuilt.
for cached_path in (training_dataset_path, validation_dataset_path):
    if CLEAN_RUN and os.path.exists(cached_path):
        shutil.rmtree(cached_path)

# Save each dataset unless a saved copy already exists at its path.
for cached_path, dataset_to_save in (
    (training_dataset_path, training_dataset),
    (validation_dataset_path, validation_dataset),
):
    if not os.path.exists(cached_path):
        dataset_to_save.save(cached_path)

In [None]:
# load dataset sets from disk
training_dataset = tf.data.Dataset.load(training_dataset_path)
validation_dataset = tf.data.Dataset.load(validation_dataset_path)

# Config and Callbacks for Model

In [None]:
# Reduce LR On no Improvement
reduce_lr = ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.1,
    patience=1,
    verbose=1,
    mode="min",
    min_delta=0.000001,
    cooldown=0,
    min_lr=1e-15,
)

# Early Stopping
early_stopping = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00001,
    patience=5,
    verbose=1,
    mode="min",
    baseline=None,
    restore_best_weights=True,
)

## Wandb
We use Weights & Biases (wandb) to track the model's performance.

In [None]:
# Authenticate with Weights & Biases; wandb uses WANDB_API_KEY from the environment when it is set.
wandb.login()

# Model Creation

### Basic CNN Model

In [None]:
cnn_model_path.mkdir(parents=True, exist_ok=True)
cnn_model_best_weights_path = cnn_model_path / "best_weights"
cnn_model_saved_model_path = cnn_model_path / "saved_model"


def _conv_stage(filter_multiplier):
    """Return [Conv2D(ReLU, L2) -> BatchNorm -> MaxPool] with N_FILTERS * filter_multiplier filters."""
    return [
        Conv2D(
            filters=CONFIGURATION["N_FILTERS"] * filter_multiplier,
            kernel_size=CONFIGURATION["KERNEL_SIZE"],
            strides=CONFIGURATION["N_STRIDES"],
            padding="valid",
            activation="relu",
            kernel_regularizer=L2(CONFIGURATION["REGULARIZATION_RATE"]),
        ),
        BatchNormalization(),
        MaxPool2D(pool_size=CONFIGURATION["POOL_SIZE"], strides=CONFIGURATION["N_STRIDES"] * 2),
    ]


def _dense_stage(units):
    """Return [Dense(ReLU, L2) -> BatchNorm] with *units* units."""
    return [
        Dense(
            units,
            activation="relu",
            kernel_regularizer=L2(CONFIGURATION["REGULARIZATION_RATE"]),
        ),
        BatchNormalization(),
    ]


# Baseline CNN: three conv stages with 1x, 2x and 3x N_FILTERS, global average
# pooling, two dense stages (dropout between them) and one sigmoid output unit.
side = CONFIGURATION["IM_SIZE"]
cnn_model = tf.keras.Sequential(
    [InputLayer(input_shape=(side, side, 3))]
    + _conv_stage(1)
    + _conv_stage(2)
    + _conv_stage(3)
    + [GlobalAveragePooling2D()]
    + _dense_stage(CONFIGURATION["N_DENSE_1"])
    + [Dropout(rate=CONFIGURATION["DROPOUT_RATE"])]
    + _dense_stage(CONFIGURATION["N_DENSE_2"])
    + [Dense(1, activation="sigmoid")]
)

In [None]:
# Print the CNN's layers, output shapes and parameter counts.
cnn_model.summary()

In [None]:
# Render the CNN architecture to a PNG and display it inline (last expression of the cell).
cnn_diagram_file = "cnn_model.png"
plot_model(cnn_model, to_file=cnn_diagram_file, show_shapes=True, show_layer_names=True)
Image(filename=cnn_diagram_file)

In [None]:
# Metric objects passed to every compile() call in this notebook. Their names
# become the keys in history (e.g. "Accuracy" and "val_Accuracy").
metrics = [BinaryAccuracy(name="Accuracy")]
metrics.extend([Precision(name="Precision"), Recall(name="Recall")])

In [None]:
# Binary cross-entropy on the single sigmoid output, optimized with Adam.
cnn_optimizer = Adam(learning_rate=CONFIGURATION["LEARNING_RATE"])
cnn_model.compile(optimizer=cnn_optimizer, loss=BinaryCrossentropy(), metrics=metrics)

#### Model Training

In [None]:
# W&B run for the baseline CNN; the name carries the timestamp suffix.
wandb_run = wandb.init(
    config=CONFIGURATION,
    entity=ENTITY,
    project=PROJECT,
    name=f"cnn_{RUN_NAME_SUFFIX}",
    job_type=f"cnn_{JOB_TYPE_SUFFIX}",
)

# Save only the weights of the epoch with the lowest val_loss so far.
model_checkpoint = ModelCheckpoint(
    cnn_model_best_weights_path,
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    save_weights_only=True,
    save_freq="epoch",
    verbose=1,
)

cnn_callbacks = [reduce_lr, model_checkpoint, early_stopping, WandbCallback(save_model=False)]
cnn_history = cnn_model.fit(
    training_dataset,
    epochs=CONFIGURATION["N_EPOCHS"],
    validation_data=validation_dataset,
    callbacks=cnn_callbacks,
)

#### CNN Model Evaluation

In [None]:
# Load best model weights from checkpoint
cnn_model.load_weights(cnn_model_best_weights_path)
# save cnn model to disk
cnn_model.save(cnn_model_saved_model_path)

In [None]:
# Report loss and metrics of the current CNN weights on the validation split.
cnn_model.evaluate(validation_dataset)

In [None]:
# Per-epoch training vs. validation loss for the CNN.
for history_key in ("loss", "val_loss"):
    plt.plot(cnn_history.history[history_key])
plt.title("CNN Model Loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend(["train_loss", "val_loss"])
plt.show()

In [None]:
# Per-epoch training vs. validation accuracy for the CNN.
for history_key in ("Accuracy", "val_Accuracy"):
    plt.plot(cnn_history.history[history_key])
plt.title("CNN Model Accuracy")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend(["train_accuracy", "val_accuracy"])
plt.show()

In [None]:
# Evaluate the CNN on every test source, log the result tables, then close the W&B run.
log_test_metrics(cnn_model, wandb_run)

wandb_run.finish()

## Using Pretrained Models

### EfficientNetV2 S

This model has problems with the checkpoint and save-weights callbacks
(ModelCheckpoint / `save_weights`). For them to work, downgrade TensorFlow with
**!pip install tensorflow==2.9.1**

In [None]:
effv2_model_dir_path.mkdir(parents=True, exist_ok=True)
effv2_model_best_weights_path, effv2_model_saved_model_path = (
    effv2_model_dir_path / leaf_name for leaf_name in ("best_weights", "saved_model")
)

# Hyper-parameters for the EfficientNetV2-S classifier. This replaces the CNN
# configuration; BATCH_SIZE and IM_SIZE keep their values, which
# resize_image() and log_test_metrics() read at call time.
CONFIGURATION = dict(
    BATCH_SIZE=32,
    IM_SIZE=IMAGE_RESOLUTION,
    DROPOUT_RATE=0.0,
    N_EPOCHS=EPOCHS,
    REGULARIZATION_RATE=0.0,
    N_FILTERS=6,
    KERNEL_SIZE=3,
    N_STRIDES=1,
    POOL_SIZE=2,
    N_DENSE_1=2048,
    N_DENSE_2=1024,
    N_DENSE_3=256,
    NUM_CLASSES=2,
    LEARNING_RATE=0.001,
)

# New W&B run (reinit=True) for training only the fully connected head.
wandb_run = wandb.init(
    config=CONFIGURATION,
    entity=ENTITY,
    project=PROJECT,
    name=f"fc_layers_{RUN_NAME_SUFFIX}",
    job_type=f"effv2s_{JOB_TYPE_SUFFIX}",
    reinit=True,
    settings=wandb.Settings(start_method="fork"),
)

In [None]:
# Model Checkpointing
model_checkpoint = ModelCheckpoint(
    effv2_model_best_weights_path,
    monitor="val_loss",
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
    mode="min",
    save_freq="epoch",
)

backbone = tf.keras.applications.EfficientNetV2S(
    include_top=False,
    weights="imagenet",
    input_tensor=None,
    input_shape=(CONFIGURATION["IM_SIZE"], CONFIGURATION["IM_SIZE"], 3),
)

In [None]:
# Freeze the pretrained backbone so the first training phase updates only the new head.
backbone.trainable = False

In [None]:
# Keras' EfficientNetV2 models include their own input preprocessing
# (include_preprocessing defaults to True and is not overridden when `backbone`
# is built), so they expect raw pixel values in [0, 255]. resize_image()
# divides by 255, so images arrive in [0, 1]. The Rescaling layer restores the
# expected range. Without it, the backbone's internal rescaling squeezes every
# pixel into a tiny band near -1, and the pretrained features carry almost no
# information.
# NOTE(review): once the backbone is unfrozen for fine-tuning, Keras'
# transfer-learning guide recommends keeping its BatchNormalization layers in
# inference mode (calling it with training=False). This Sequential model does
# not do that.
efficientnetv2s_model = tf.keras.Sequential(
    [
        Input(shape=(CONFIGURATION["IM_SIZE"], CONFIGURATION["IM_SIZE"], 3)),
        tf.keras.layers.Rescaling(scale=255.0),
        backbone,
        GlobalAveragePooling2D(),
        # Classification head: three ReLU dense layers, each followed by batch norm.
        Dense(CONFIGURATION["N_DENSE_1"], activation="relu"),
        BatchNormalization(),
        Dropout(rate=CONFIGURATION["DROPOUT_RATE"]),
        Dense(CONFIGURATION["N_DENSE_2"], activation="relu"),
        BatchNormalization(),
        Dropout(rate=CONFIGURATION["DROPOUT_RATE"]),
        Dense(CONFIGURATION["N_DENSE_3"], activation="relu"),
        BatchNormalization(),
        # Single sigmoid unit: predicted probability of label 1.0 (AI-generated).
        Dense(1, activation="sigmoid"),
    ]
)

efficientnetv2s_model.summary()

In [None]:
# Render the EfficientNetV2-S model to a PNG and display it inline (last expression of the cell).
effv2_diagram_file = "efficientnet_b4_model.png"
plot_model(efficientnetv2s_model, to_file=effv2_diagram_file, show_shapes=True, show_layer_names=True)
Image(filename=effv2_diagram_file)

In [None]:
# Binary cross-entropy with Adam at the configured learning rate for the head-only phase.
head_optimizer = Adam(learning_rate=CONFIGURATION["LEARNING_RATE"])
efficientnetv2s_model.compile(optimizer=head_optimizer, loss=BinaryCrossentropy(), metrics=metrics)

In [None]:
# Train the classification head (backbone.trainable was set to False above).
head_callbacks = [reduce_lr, model_checkpoint, early_stopping, WandbCallback(save_model=False)]
efficientnetv2s_history = efficientnetv2s_model.fit(
    training_dataset,
    epochs=CONFIGURATION["N_EPOCHS"],
    validation_data=validation_dataset,
    callbacks=head_callbacks,
)

# Evaluate with the weights the model holds after fit() (the best checkpoint is
# not reloaded here), then close the run.
log_test_metrics(efficientnetv2s_model, wandb_run)
wandb_run.finish()

In [None]:
# Per-epoch training vs. validation loss for the head-only phase.
for history_key in ("loss", "val_loss"):
    plt.plot(efficientnetv2s_history.history[history_key])
plt.title("EfficientNetV2 S Model Loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend(["train_loss", "val_loss"])
plt.show()

In [None]:
# Per-epoch training vs. validation accuracy for the head-only phase.
for history_key in ("Accuracy", "val_Accuracy"):
    plt.plot(efficientnetv2s_history.history[history_key])
plt.title("EfficientNetV2 S Model Accuracy")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend(["train_accuracy", "val_accuracy"])
plt.show()

#### Fine-tuning the model

Because of an error, I believe I lost the pretrained model whose output layers had already been trained. Since I had already used a lot of GPU time, I trained the model with all layers set to `trainable=True` from the start instead of fine-tuning it.

If you want to use this pretrained model instead, remember to downgrade TensorFlow first.

In [None]:
# Unfreeze the backbone so every layer is updated in the fine-tuning phase (takes effect after compile()).
backbone.trainable = True

In [None]:
# Summary again; the trainable-parameter count now includes the backbone.
efficientnetv2s_model.summary()

In [None]:
# Recompile so the trainable change takes effect, with a learning rate 100x
# smaller than the head-only phase.
fine_tune_learning_rate = CONFIGURATION["LEARNING_RATE"] / 100
efficientnetv2s_model.compile(
    optimizer=Adam(learning_rate=fine_tune_learning_rate),
    loss=BinaryCrossentropy(),
    metrics=metrics,
)

In [None]:
# Separate W&B run for fine-tuning all layers.
wandb_run = wandb.init(
    config=CONFIGURATION,
    entity=ENTITY,
    project=PROJECT,
    name=f"all_layers_{RUN_NAME_SUFFIX}",
    job_type=f"effv2s_{JOB_TYPE_SUFFIX}",
    reinit=True,
)

# New checkpoint callback that tracks this phase's best val_loss. It writes to
# the same path as the head-only phase and overwrites that checkpoint.
model_checkpoint = ModelCheckpoint(
    effv2_model_best_weights_path,
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    save_weights_only=True,
    save_freq="epoch",
    verbose=1,
)

# efficientnetv2s_model.load_weights(effv2_model_best_weights_path)

# Fine-tune for three times the configured number of epochs.
fine_tune_callbacks = [reduce_lr, model_checkpoint, early_stopping, WandbCallback(save_model=False)]
efficientnetv2s_history = efficientnetv2s_model.fit(
    training_dataset,
    epochs=CONFIGURATION["N_EPOCHS"] * 3,
    validation_data=validation_dataset,
    callbacks=fine_tune_callbacks,
)

#### EfficientNetV2 S Model Evaluation after Fine-tuning

In [None]:
# Load best model
efficientnetv2s_model.load_weights(effv2_model_best_weights_path)
efficientnetv2s_model.save(effv2_model_saved_model_path)

In [None]:
# Validation-split metrics with the restored weights.
efficientnetv2s_model.evaluate(validation_dataset)

# Per-source and overall test metrics to W&B, then close the fine-tuning run.
log_test_metrics(efficientnetv2s_model, wandb_run)
wandb_run.finish()

In [None]:
# Per-epoch training vs. validation loss for the fine-tuning phase.
for history_key in ("loss", "val_loss"):
    plt.plot(efficientnetv2s_history.history[history_key])
plt.title("EfficientNetV2 S Model Loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend(["train_loss", "val_loss"])
plt.show()

In [None]:
# Per-epoch training vs. validation accuracy for the fine-tuning phase.
for history_key in ("Accuracy", "val_Accuracy"):
    plt.plot(efficientnetv2s_history.history[history_key])
plt.title("EfficientNetV2 S Model Accuracy")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend(["train_accuracy", "val_accuracy"])
plt.show()