In [8]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import confusion_matrix
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [9]:
# `base_path` is set in the configuration cell below, together with the other
# run settings, so that there is a single place to change it.

In [10]:
# Run configuration.
# Dataset root holding the `Train` and `Validation` folders read below.
# Set the DEEPFAKE_DATASET_DIR environment variable to override the
# machine-specific default instead of editing this cell.
base_path = os.environ.get('DEEPFAKE_DATASET_DIR', 'E:/deepfake/deepfake_resnet50/Dataset/')
input_shape = (128, 128, 3)  # (height, width, channels) of the model input
batch_size = 64
# ImageDataGenerator augments on the fly; this factor does NOT enlarge the
# dataset. It only divides the per-epoch step budget computed below.
augmentation_factor = 2
small_train_images = 1000  # image budget used to size a quick training run
# Steps per epoch for the quick run (1000 // (64 * 2) == 7 with these values).
# Clamped to 1 so a larger batch size can never yield 0 steps.
small_train_steps = max(1, small_train_images // (batch_size * augmentation_factor))

In [11]:
def build_model(shape=None, learning_rate=0.001, num_classes=2, freeze_backbone=False):
    """Build and compile a ResNet50-based image classifier.

    Architecture: ImageNet-pretrained ResNet50 without its top layers, then
    global average pooling, a 512-unit ReLU dense layer, batch normalization
    and a softmax output layer.

    Args:
        shape: Input shape ``(height, width, channels)``. When ``None`` the
            notebook-level ``input_shape`` is used.
        learning_rate: Learning rate for the Adam optimizer.
        num_classes: Number of softmax output units.
        freeze_backbone: If True, the ResNet50 weights are excluded from
            training and only the new head is updated. The default (False)
            fine-tunes the whole network.

    Returns:
        A ``Sequential`` model compiled with Adam, categorical cross-entropy
        loss and accuracy as the metric.
    """
    if shape is None:
        shape = input_shape  # fall back to the configuration cell's value

    backbone = ResNet50(
        weights='imagenet',
        include_top=False,
        input_shape=shape,
    )
    backbone.trainable = not freeze_backbone

    model = Sequential([
        backbone,
        GlobalAveragePooling2D(),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dense(num_classes, activation='softmax'),
    ])
    model.compile(
        # `learning_rate` replaces the deprecated `lr` alias, which triggered
        # the deprecation warning and is rejected by newer Keras releases.
        optimizer=Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

model = build_model()
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50 (Functional)       (None, 4, 4, 2048)        23587712  
                                                                 
 global_average_pooling2d (G  (None, 2048)             0         
 lobalAveragePooling2D)                                          
                                                                 
 dense (Dense)               (None, 512)               1049088   
                                                                 
 batch_normalization (BatchN  (None, 512)              2048      
 ormalization)                                                   
                                                                 
 dense_1 (Dense)             (None, 2)                 1026      
                                                                 
Total params: 24,639,874
Trainable params: 24,585,730
Non-trainable params: 54,144

  super().__init__(name, **kwargs)


In [12]:
# Training-time image generator. Pixels are multiplied by 1/255, and every
# image gets a random geometric augmentation: rotation, width/height shift,
# shear, zoom and horizontal flip. Pixels exposed by a transform are filled
# in 'nearest' mode.
# NOTE(review): ResNet50's ImageNet weights were trained on inputs prepared by
# tf.keras.applications.resnet50.preprocess_input (mean subtraction on 0-255
# values), not 1/255 scaling. Confirm this mismatch is intended.
augmented_generator = ImageDataGenerator(
    horizontal_flip=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    fill_mode='nearest',
    rescale=1 / 255,
)

In [13]:
# Training batches read from <base_path>/Train, resized to the model input
# resolution and passed through the augmenting generator. Labels are one-hot
# ('categorical') to match the softmax / categorical cross-entropy head.
augmented_train_flow = augmented_generator.flow_from_directory(
    os.path.join(base_path, 'Train'),
    # Derived from `input_shape` so the image size and model input cannot drift apart.
    target_size=input_shape[:2],
    batch_size=batch_size,
    class_mode='categorical',
)

Found 140002 images belonging to 2 classes.


In [14]:
# Validation batches read from <base_path>/Validation. This generator only
# rescales pixels (no augmentation), using the same 1/255 factor as training.
valid_generator = ImageDataGenerator(rescale=1. / 255.)
valid_flow = valid_generator.flow_from_directory(
    os.path.join(base_path, 'Validation'),
    # Derived from `input_shape` so the image size and model input cannot drift apart.
    target_size=input_shape[:2],
    batch_size=batch_size,
    class_mode='categorical',
    # Explicitly shuffled (the Keras default). Training validates on only part
    # of this set each epoch, and an unshuffled, class-ordered listing would
    # likely make that part mostly one class.
    shuffle=True,
)


Found 39428 images belonging to 2 classes.


In [15]:
# Callback that prints a validation confusion matrix after every epoch, which
# makes a model that predicts only one class easy to spot.
class PredictionCallback(tf.keras.callbacks.Callback):
    """Print a confusion matrix on validation batches at the end of each epoch.

    Args:
        eval_data: Batch-indexable sequence (e.g. a Keras ``DirectoryIterator``)
            whose items are ``(images, one_hot_labels)`` tuples. When ``None``,
            the notebook-level ``valid_flow`` is looked up at epoch end.
        num_batches: Number of batches, starting at index 0, to evaluate. The
            default of 1 matches the original single-batch check. One batch is
            a noisy estimate, so raise this for a steadier matrix.
    """

    def __init__(self, eval_data=None, num_batches=1):
        super().__init__()
        self.eval_data = eval_data
        self.num_batches = num_batches

    def on_epoch_end(self, epoch, logs=None):
        # `logs=None` avoids sharing a mutable default dict between calls.
        data = valid_flow if self.eval_data is None else self.eval_data
        n_batches = min(self.num_batches, len(data))
        if n_batches < 1:
            return

        image_batches, label_batches = [], []
        for index in range(n_batches):
            x_batch, y_batch = data[index]  # load each batch only once
            image_batches.append(x_batch)
            label_batches.append(y_batch)
        y_test = np.concatenate(label_batches)
        y_pred = self.model.predict(np.concatenate(image_batches), verbose=0)

        y_pred_labels = np.argmax(y_pred, axis=1)
        y_test_labels = np.argmax(y_test, axis=1)
        # Pin the label set so the matrix stays num_classes x num_classes even
        # when every true and predicted label falls in a single class.
        cfm = confusion_matrix(
            y_test_labels, y_pred_labels, labels=np.arange(y_test.shape[1])
        )
        print(cfm)
        # Spot check: predicted probabilities vs. one-hot truth for the first sample.
        print(y_pred[0], y_test[0])

In [17]:
# Validation budget per epoch: about 10,000 images, never more batches than
# the validation generator actually holds, and at least one batch.
validation_images = 10_000
valid_steps = max(1, min(validation_images // batch_size, len(valid_flow)))
num_epochs = 10

# Train on the augmenting generator. `steps_per_epoch` limits every epoch to
# `small_train_steps` batches, and the callback prints a confusion matrix on
# validation data at the end of each epoch.
history = model.fit(
    augmented_train_flow,
    epochs=num_epochs,
    steps_per_epoch=small_train_steps,
    validation_data=valid_flow,
    validation_steps=valid_steps,
    callbacks=[PredictionCallback()],
)

Epoch 1/10
[[ 0 27]
 [ 0 37]]
[5.8830176e-19 1.0000000e+00] [1. 0.]
Epoch 2/10
[[29  0]
 [35  0]]
[1. 0.] [1. 0.]
Epoch 3/10
[[ 0 34]
 [ 0 30]]
[0. 1.] [0. 1.]
Epoch 4/10
[[39  0]
 [25  0]]
[1. 0.] [0. 1.]
Epoch 5/10
[[ 0 27]
 [ 0 37]]
[0. 1.] [0. 1.]
Epoch 6/10
[[ 0 30]
 [ 0 34]]
[0. 1.] [1. 0.]
Epoch 7/10
[[ 0 36]
 [ 0 28]]
[0. 1.] [0. 1.]
Epoch 8/10
[[ 0 31]
 [ 0 33]]
[0. 1.] [1. 0.]
Epoch 9/10
[[33  0]
 [31  0]]
[0.61904 0.38096] [0. 1.]
Epoch 10/10
[[ 0 32]
 [ 0 32]]
[0.40732047 0.5926796 ] [1. 0.]


In [19]:
# Write the trained model to disk. The `.h5` extension selects the HDF5
# format, and the path is relative to the notebook's working directory.
model.save(filepath='my_model.h5')