### Imports

In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
from datetime import datetime as dt
import glob

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix

from tensorflow.keras.applications import VGG16, VGG19
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, Input, Dropout, Rescaling, Conv2D, MaxPooling2D, Flatten, Dropout, Activation, BatchNormalization
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.losses import BinaryCrossentropy, SparseCategoricalCrossentropy 
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.regularizers import l2
from tensorflow.keras.metrics import Precision, Recall, F1Score
from tensorflow.keras import regularizers

from sklearn.model_selection import GridSearchCV
import matplotlib.pyplot as plt
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px

import warnings 
warnings.filterwarnings('ignore')

In [None]:
# Report how many GPU devices TensorFlow lists as physically available.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

In [4]:
# Dataset locations (relative paths).
DATASET_FOLDER_TRAIN = '../CIFAKE/train'
DATASET_FOLDER_TEST = '../CIFAKE/test'

# Settings handed to ImageDataGenerator.flow_from_directory.
BATCH_SIZE = 32
COLOR_MODE = 'rgb'
CLASS_MODE = 'binary'
TARGET_SIZE = (32, 32)

# A preceding assignment of 0.0005 was dead (immediately overwritten by this
# value) and has been removed; the effective value is unchanged.
# NOTE(review): in the cells visible here LEARN_RATE, LOSS_FN and N_EPOCHS are
# not referenced (the optimizer, loss and epoch count are given as literals
# where the model is compiled and fitted) - confirm whether that is intended.
LEARN_RATE = 0.005
SEED = 42

LOSS_FN = BinaryCrossentropy()

N_EPOCHS = 50

### Train/test sets

In [None]:
# Generator configuration for the training folder: pixel values are scaled by
# 1/255 and 20% of the images are set aside as the validation subset.
# Augmentation options left disabled here: rotation_range=45,
# width_shift_range=0.5, height_shift_range=0.5, zoom_range=0.5,
# horizontal_flip=True, vertical_flip=True.
train_datagen_options = {
    'rescale': 1./255,
    'validation_split': 0.2,
}
train_datagen = ImageDataGenerator(**train_datagen_options)

# Options common to the training and validation iterators; both read from the
# training folder and differ only in subset / shuffling arguments.
shared_flow_options = dict(
    directory=DATASET_FOLDER_TRAIN,
    target_size=TARGET_SIZE,
    color_mode=COLOR_MODE,
    batch_size=BATCH_SIZE,
    class_mode=CLASS_MODE,
)

# 'training' subset, created with a fixed seed.
train_generator = train_datagen.flow_from_directory(
    subset='training',
    seed=SEED,
    **shared_flow_options,
)

# 'validation' subset, created with shuffling turned off.
validation_generator = train_datagen.flow_from_directory(
    subset='validation',
    shuffle=False,
    **shared_flow_options,
)

# The test images only get the same 1/255 rescaling as the training images.
test_datagen = ImageDataGenerator(rescale=1./255)

# shuffle=False is required: the evaluation code compares the model's
# predictions against test_generator.classes, and the two only line up when
# the iterator yields the images in a fixed, unshuffled order
# (flow_from_directory shuffles by default).
test_generator = test_datagen.flow_from_directory(
    DATASET_FOLDER_TEST,
    target_size=TARGET_SIZE,
    color_mode=COLOR_MODE,
    batch_size=BATCH_SIZE,
    class_mode=CLASS_MODE,
    shuffle=False,
)

### Class Balance

In [None]:
# Count the files in each class folder of the train and test splits.
# The path is built in its own statement so that no f-string has to reuse its
# delimiting quote inside a replacement field (only legal from Python 3.12).
print('Class distribution: ')
for split_name in ('train', 'test'):
    for class_name in ('REAL', 'FAKE'):
        image_paths = glob.glob(f"../CIFAKE/{split_name}/{class_name}/*")
        print(f"{split_name.capitalize()} {class_name} images: {len(image_paths)}")

### Transfer learning model

In [None]:
# VGG16 without its classifier top, loaded with ImageNet weights for 32x32 RGB
# inputs and with max pooling applied to the final feature maps.
VGG_base_model = VGG16(
    weights='imagenet',
    include_top=False,
    pooling='max',
    input_shape=(32, 32, 3),
)
# The base model's weights are left trainable.
VGG_base_model.trainable = True

# Classification head stacked on top of the VGG16 base.
inputs = Input(shape=(32, 32, 3))

# The base model is called with training=False.
x = VGG_base_model(inputs, training=False)

# Hidden layers of the head, applied in the listed order.
head_layers = [
    BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001),
    Dense(
        256,
        activation='relu',
        kernel_regularizer=regularizers.l2(0.01),
        activity_regularizer=regularizers.l1(0.01),
        bias_regularizer=regularizers.l1(0.01),
    ),
    Dropout(rate=.4, seed=512),
    Dense(64, activation='relu'),
]
for head_layer in head_layers:
    x = head_layer(x)

# Single sigmoid unit for the binary output.
outputs = Dense(1, activation='sigmoid')(x)
VGG_model = Model(inputs, outputs)

# Compile with Adamax and binary cross-entropy.
# The precision/recall metrics are named explicitly: the plotting code reads
# hist['precision'], hist['val_precision'], hist['recall'] and
# hist['val_recall'], and auto-generated metric names get a numeric suffix
# when this cell is executed more than once in the same session.
VGG_model.compile(
    optimizer=tf.keras.optimizers.Adamax(learning_rate=.001),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[
        'accuracy',
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
    ],
)

VGG_model.summary(expand_nested=True)

### Train model

In [None]:
# Stop when val_accuracy has not improved for 5 epochs and restore the best
# weights seen.
# NOTE(review): epochs=3 is smaller than patience=5, so this callback
# presumably never stops training early; N_EPOCHS is defined but not used
# here - confirm which epoch count is intended.
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    patience=5,
    restore_best_weights=True,
)

# Whole batches per pass (floor division).
train_steps = train_generator.samples // BATCH_SIZE
validation_steps = validation_generator.samples // BATCH_SIZE

history = VGG_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=3,
    steps_per_epoch=train_steps,
    validation_steps=validation_steps,
    callbacks=[early_stopping],
)

In [9]:
# Metric values recorded by fit(), plus the list of recorded metric keys.
hist = history.history
cols = list(hist)

### Evaluation

In [None]:
# 2x2 grid of training curves: loss, precision, accuracy, recall.
fig_metrics = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=("Loss", "Precision", "Accuracy", "Recall"),
    vertical_spacing=0.07,
)

# (history key, y-axis title, legend suffix, row, col) for each panel.
# The order matters: traces are later picked out of fig_metrics['data'] by
# position.
metric_panels = [
    ('loss', "Loss", 'Loss', 1, 1),
    ('precision', "Precision", 'precision', 1, 2),
    ('accuracy', "Accuracy", 'accuracy', 2, 1),
    ('recall', "Recall", 'recall', 2, 2),
]

for metric_key, axis_title, legend_suffix, panel_row, panel_col in metric_panels:
    # Training curve first, then the validation curve, in the same panel.
    for split_label, history_key in (('Train', metric_key), ('Val', f'val_{metric_key}')):
        values = hist[history_key]
        fig_metrics.add_trace(
            go.Scatter(
                x=list(range(len(values))),
                y=values,
                mode='lines+markers',
                name=f'{split_label} {legend_suffix}',
            ),
            row=panel_row,
            col=panel_col,
        )
    fig_metrics.update_yaxes(title_text=axis_title, row=panel_row, col=panel_col)

fig_metrics.update_layout(
    width=1400,
    height=800,
    showlegend=True,
    margin=dict(l=10, r=10, b=10, t=30),
)

# Nudge every subplot title upwards slightly.
for title_annotation in fig_metrics['layout']['annotations']:
    title_annotation['y'] += 0.002


# Number of batches needed to cover every test image. Rounding up keeps the
# final partial batch that floor division would silently drop whenever the
# sample count is not a multiple of the batch size.
test_steps = int(np.ceil(test_generator.samples / test_generator.batch_size))

# evaluate() returns the loss followed by the compiled metrics
# (accuracy, precision, recall).
test_loss, test_acc, test_prec, test_recall = VGG_model.evaluate(
    test_generator,
    steps=test_steps,
)

# Confusion matrix
from sklearn.metrics import confusion_matrix

# Start predicting from the first batch again after evaluate() consumed the
# generator.
test_generator.reset()

# The model has a single sigmoid output, so predictions are flattened to a 1-D
# score vector and thresholded at 0.5 (argmax would not apply here).
y_pred = VGG_model.predict(test_generator, steps=test_steps).ravel()
y_pred_classes = (y_pred > 0.5).astype(int)

# NOTE: test_generator.classes is expected to be in file order, so it lines up
# with y_pred only if the generator yields images unshuffled - confirm the
# generator is created with shuffle=False.
y_true = test_generator.classes[:len(y_pred)]

cm = confusion_matrix(y_true, y_pred_classes)

# Class names taken from the generator's class-name -> index mapping.
class_labels = list(test_generator.class_indices.keys())

# Confusion matrix rendered as an annotated plotly heatmap:
# x carries the predicted labels, y the true labels, and each cell shows its
# count as text.
confusion_heatmap = go.Heatmap(
    z=cm,
    x=class_labels,
    y=class_labels,
    text=cm,
    texttemplate="%{text}",
    textfont=dict(size=15),
    colorscale='Blues',
    showscale=True,
    hoverongaps=False,
)
fig_confMatrix = go.Figure(data=confusion_heatmap)

confusion_layout = {
    'title': 'Confusion Matrix',
    'xaxis_title': 'Predicted Label',
    'yaxis_title': 'True Label',
    'width': 600,
    'height': 500,
}
fig_confMatrix.update_layout(**confusion_layout)

# ROC AUC
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc

# The ROC curve must be built from the continuous sigmoid scores, not from the
# thresholded 0/1 predictions: hard labels give a single operating point and a
# misleading AUC. np.ravel accepts both an (N, 1) column and a flat vector.
fpr, tpr, _ = roc_curve(y_true, np.ravel(y_pred))
roc_auc = auc(fpr, tpr)

fig_rocauc = go.Figure()

# ROC curve itself; the AUC value is shown in the legend entry.
fig_rocauc.add_trace(go.Scatter(
    x=fpr, y=tpr,
    mode='lines',
    line=dict(color='blue', width=2),
    name=f'ROC curve (AUC = {roc_auc:0.2f})'
))

# Dashed diagonal as the chance-level reference line.
fig_rocauc.add_trace(go.Scatter(
    x=[0, 1], y=[0, 1],
    mode='lines',
    line=dict(color='black', dash='dash'),
    showlegend=False,
    hoverinfo='skip'
))

fig_rocauc.update_layout(
    title='ROC AUC for Binary Classification',
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate',
    width=700,
    height=600,
    legend=dict(x=0.6, y=0.1),
    margin=dict(l=40, r=40, t=40, b=40),
)



# Text classification report for the two classes (label ids 0 and 1, named
# after class_labels).
from sklearn.metrics import classification_report

# NOTE(review): zero_division is passed as False; presumably this behaves like
# 0 - confirm against the scikit-learn documentation.
report_options = dict(
    target_names=class_labels,
    labels=[0, 1],
    zero_division=False,
)
report = classification_report(y_true, y_pred_classes, **report_options)


### Evaluation plots

In [None]:
# Combined 4x2 report figure. Row 1 has no titles; rows 2-4 hold the training
# curves, the confusion matrix and the ROC curve.
combined_titles = (
    "", "",
    "Loss", "Precision",
    "Accuracy", 'Recall',
    'Confusion Matrix', 'ROC-AUC curve',
)
fig = make_subplots(
    rows=4,
    cols=2,
    subplot_titles=combined_titles,
    horizontal_spacing=0.05,
    vertical_spacing=0.05,
)

# Convert the report's newlines to HTML line breaks outside the f-string:
# a backslash inside an f-string expression is a SyntaxError before
# Python 3.12.
report_html = report.replace('\n', '<br>')

summary_text = (
    f"Test loss: {test_loss:.4f}<br>"
    f"Test accuracy: {test_acc:.4f}<br>"
    f"Test precision: {test_prec:.4f}<br>"
    f"Test recall: {test_recall:.4f}<br><br>"
    f"ROC-AUC: {roc_auc:.4f}<br><br>"
    f"Classification report<br>"
    f"{report_html}"
)

# The test metrics are shown as a text-only trace in the top-right cell.
fig.add_trace(
    go.Scatter(
        x=[0.5], y=[0.5],
        text=[summary_text],
        mode='text',
        showlegend=False,
    ),
    row=1, col=2
)

# Hide the axes of both cells in the first row. Previously only column 1 was
# hidden while the text sits in column 2, leaving axes drawn around the text.
for summary_col in (1, 2):
    fig.update_xaxes(visible=False, row=1, col=summary_col)
    fig.update_yaxes(visible=False, row=1, col=summary_col)

# Copy the eight training-curve traces into rows 2-3. Traces come in
# (train, val) pairs, one pair per panel, in the order
# loss, precision, accuracy, recall.
metric_traces = fig_metrics['data']
for trace_idx in range(8):
    panel_idx = trace_idx // 2
    fig.add_trace(
        metric_traces[trace_idx],
        row=2 + panel_idx // 2,
        col=1 + panel_idx % 2,
    )

# Bottom row: confusion-matrix heatmap on the left, then the ROC curve and its
# diagonal reference line on the right.
fig.add_trace(fig_confMatrix['data'][0], row=4, col=1)
for roc_trace_idx in (0, 1):
    fig.add_trace(fig_rocauc['data'][roc_trace_idx], row=4, col=2)

# Final sizing: 400 px per subplot row, legend hidden.
# A figure title (title_text=f"--- {MODEL.name} ---") is left disabled.
combined_layout = {
    'width': 1400,
    'height': 400*4,
    'showlegend': False,
    'margin': dict(l=10, r=10, t=50, b=10),
}
fig.update_layout(**combined_layout)

fig.show()